API¶
GPipe Module¶
- class torchgpipe.GPipe(module, balance, **kwargs)¶

Wraps an arbitrary nn.Sequential module to train on GPipe. If the module requires lots of memory, GPipe will be very efficient.

```python
model = nn.Sequential(a, b, c, d)
model = GPipe(model, balance=[1, 1, 1, 1], chunks=8)
output = model(input)
```
GPipe combines pipeline parallelism with checkpointing to reduce peak memory required to train while minimizing device under-utilization.
You should determine the balance when defining a GPipe module, as balancing will not be done automatically. The module will be partitioned across multiple devices according to the given balance. You may rely on heuristics to find your own optimal configuration.

- Parameters
module (torch.nn.Sequential) – sequential module to be parallelized
balance (ints) – list of the number of layers in each partition
- Keyword Arguments
devices (iterable of devices) – devices to use (default: all CUDA devices)
chunks (int) – number of micro-batches (default: 1)
checkpoint (str) – when to enable checkpointing, one of 'always', 'except_last', or 'never' (default: 'except_last')
deferred_batch_norm (bool) – whether to use deferred BatchNorm moving statistics (default: False, see Deferred Batch Normalization for more details). A sketch combining these keyword arguments follows the Raises list below.
- Raises
TypeError – the module is not a nn.Sequential.
ValueError – invalid arguments, or wrong balance.
IndexError – the number of devices is fewer than the number of partitions.
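As an illustration of the keyword arguments above, here is a minimal sketch; the layer sizes, the two-device list, and the balance are illustrative values rather than anything from the original documentation:

```python
import torch
from torch import nn
from torchgpipe import GPipe

# Four tiny layers just to make the sketch concrete.
a, b, c, d = nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16), nn.ReLU()

model = nn.Sequential(a, b, c, d)
model = GPipe(
    model,
    balance=[2, 2],                      # two layers per partition
    devices=[torch.device('cuda:0'),     # explicit devices instead of
             torch.device('cuda:1')],    # all available CUDA devices
    chunks=8,                            # number of micro-batches
    checkpoint='except_last',            # the default checkpointing mode
    deferred_batch_norm=True,            # deferred BatchNorm moving statistics
)

input = torch.rand(32, 16, device=model.devices[0])  # mini-batch on the first partition's device
output = model(input)
```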
- forward(input)¶

GPipe is a fairly transparent module wrapper. It doesn't modify the input and output signature of the underlying module, but there is a type restriction: the input and output have to be a Tensor or a tuple of tensors. This restriction is applied at partition boundaries too.

- Parameters
input (torch.Tensor or tensors) – input mini-batch
- Returns
output mini-batch
- Return type
tensor or tensors
- Raises
TypeError – input is not a tensor or tensors.
- balance¶

The number of layers in each partition.
- devices¶

The devices mapped to each partition. devices[-1] refers to the device of the last partition, which means it is the output device. You will probably need it to transfer the target to the output device, so that the loss can be computed without a device-mismatch RuntimeError. For example:

```python
out_device = gpipe.devices[-1]

for input, target in loader:
    target = target.to(out_device, non_blocking=True)
    output = gpipe(input)
    loss = F.cross_entropy(output, target)
```
- chunks¶

The number of micro-batches.
- checkpoint¶

The checkpoint mode determining when to enable checkpointing. It is one of 'always', 'except_last', or 'never'.
Skip Connections¶
- @torchgpipe.skip.skippable([stash][, pop])¶

The decorator to define a nn.Module with skip connections. Decorated modules are called "skippable". This functionality works perfectly fine even when the module is not wrapped by GPipe.

Each skip tensor is managed by its name. Before manipulating skip tensors, a skippable module must statically declare the names of its skip tensors via the stash and/or pop parameters. Skip tensors with a pre-declared name can be stashed by yield stash(name, tensor) or popped by tensor = yield pop(name).

Here is an example with three layers. A skip tensor named "1to3" is stashed and popped at the first and last layer, respectively:

```python
@skippable(stash=['1to3'])
class Layer1(nn.Module):
    def forward(self, input):
        yield stash('1to3', input)
        return f1(input)

class Layer2(nn.Module):
    def forward(self, input):
        return f2(input)

@skippable(pop=['1to3'])
class Layer3(nn.Module):
    def forward(self, input):
        skip_1to3 = yield pop('1to3')
        return f3(input) + skip_1to3

model = nn.Sequential(Layer1(), Layer2(), Layer3())
```
One skippable module can stash or pop multiple skip tensors:
```python
@skippable(stash=['alice', 'bob'], pop=['carol'])
class StashStashPop(nn.Module):
    def forward(self, input):
        yield stash('alice', f_alice(input))
        yield stash('bob', f_bob(input))
        carol = yield pop('carol')
        return input + carol
```
Every skip tensor must be associated with exactly one pair of stash and pop. GPipe checks this restriction automatically when wrapping a module. You can also check the restriction with verify_skippables() without using GPipe.

Note

@skippable changes the type of the wrapped class, but mypy (currently, as of v0.740) cannot understand class decorators yet (#3135). There are two workarounds:

- Naively ignore type errors with # type: ignore.
- Use skippable()() as a function instead of a decorator, as in the sketch below.
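The second workaround can look like the following minimal sketch, which reuses the Layer1 example from above; the name Layer1Impl is hypothetical and f1 merely stands in for an arbitrary computation:

```python
from torch import nn
from torchgpipe.skip import skippable, stash

def f1(input):                  # stands in for an arbitrary computation
    return input

class Layer1Impl(nn.Module):    # the undecorated class, under a hypothetical name
    def forward(self, input):
        yield stash('1to3', input)
        return f1(input)

# Call skippable() as a plain function instead of using it as a class
# decorator, so mypy never has to analyze a decorated class definition.
Layer1 = skippable(stash=['1to3'])(Layer1Impl)
```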
- Skippable.isolate(ns[, only=names])¶

Isolates a specified subset or the whole set of skip tensors into a namespace. In a single sequential module, skip tensors with the same name are not allowed unless they are isolated by different namespaces.

Here's an example using the same name for skip tensors twice. Each pair of Layer1 and Layer3 is isolated with its own namespace, ns1 or ns2, so there is no conflict anymore:

```python
ns1 = Namespace()
ns2 = Namespace()

model = nn.Sequential(
    Layer1().isolate(ns1),
    Layer1().isolate(ns2),
    Layer2(),
    Layer3().isolate(ns2),
    Layer3().isolate(ns1),
)
```
When the only parameter is omitted, all skip tensors are isolated. You can isolate a subset of skip tensors by passing the only parameter:

```python
ns_alice = Namespace()
ns_bob = Namespace()

model = nn.Sequential(
    ...
    StashStashPop().isolate(ns_alice, only=['alice']) \
                   .isolate(ns_bob, only=['bob']),
    ...
)
```
- Parameters
ns (Namespace) – namespace for isolation
- Keyword Arguments
only (iterable of strs) – names of specific skip tensors to be isolated (omit this option to isolate all skip tensors declared in this module)
- Returns
this module itself
- torchgpipe.skip.stash(name, tensor)¶

The command to stash a skip tensor.

```python
def forward(self, input):
    yield stash('name', input)
    return f(input)
```
- Parameters
name (str) – name of skip tensor
tensor (torch.Tensor or None) – tensor to pass to the skip connection
- torchgpipe.skip.pop(name)¶

The command to pop a skip tensor.

```python
def forward(self, input):
    skip = yield pop('name')
    return f(input) + skip
```
- Parameters
name (str) – name of skip tensor
- Returns
the skip tensor previously stashed by another layer under the same name
- torchgpipe.skip.verify_skippables(module)¶

Verifies that the underlying skippable modules satisfy integrity.

Every skip tensor must have exactly one pair of stash and pop. If there are one or more unmatched pairs, it will raise TypeError with detailed messages.

Here are a few failure cases. verify_skippables() will report failure for these cases:

```python
# Layer1 stashes "1to3".
# Layer3 pops "1to3".

nn.Sequential(Layer1(), Layer2())
#                   └──── ?

nn.Sequential(Layer2(), Layer3())
#                   ? ────┘

nn.Sequential(Layer1(), Layer2(), Layer3(), Layer3())
#                   └───────────────────┘    ^^^^^^

nn.Sequential(Layer1(), Layer1(), Layer2(), Layer3())
#             ^^^^^^    └───────────────────┘
```
To use the same name for multiple skip tensors, they must be isolated by different namespaces. See isolate().

- Raises
TypeError – one or more pairs of stash and pop are not matched.
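As a quick illustration, here is a minimal self-contained sketch of calling verify_skippables() on a sequence with an unmatched stash; the Stash module below is hypothetical:

```python
from torch import nn
from torchgpipe.skip import skippable, stash, verify_skippables

@skippable(stash=['1to3'])
class Stash(nn.Module):          # stashes "1to3", but nothing pops it
    def forward(self, input):
        yield stash('1to3', input)
        return input

try:
    verify_skippables(nn.Sequential(Stash(), nn.ReLU()))
except TypeError as error:
    print(error)                 # reports the unmatched "1to3" stash
```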
Inspecting GPipe Timeline¶
- torchgpipe.is_checkpointing()¶

Whether the current forward propagation is under checkpointing.
- torchgpipe.is_recomputing()¶

Whether the current forward propagation is under checkpoint recomputation. Use this to prevent duplicated side-effects at forward propagation:

```python
class Counter(nn.Module):
    def __init__(self):
        super().__init__()
        self.counter = 0

    def forward(self, input):
        if not is_recomputing():
            self.counter += 1
        return input
```
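As an illustrative sketch that is not part of the original documentation, the two functions can also be combined to trace which phase of the GPipe timeline a layer is running in; the Phase module below is hypothetical:

```python
from torch import nn
from torchgpipe import is_checkpointing, is_recomputing

class Phase(nn.Module):
    """Prints which phase of the timeline the forward pass is in."""

    def forward(self, input):
        if is_checkpointing():
            print('forward under checkpointing; it will be recomputed during backprop')
        elif is_recomputing():
            print('recomputing the forward pass during backpropagation')
        else:
            print('plain forward pass without checkpointing')
        return input
```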
Automatic Balancing¶
- torchgpipe.balance.balance_by_time(partitions, module, sample, timeout=1.0, device=torch.device('cuda'))¶

Naive automatic balancing by elapsed time per layer.

```python
sample = torch.empty(128, 3, 224, 224)
balance = balance_by_time(torch.cuda.device_count(), model, sample)
gpipe = GPipe(model, balance, chunks=8)
```
- Parameters
partitions (int) – intended number of partitions
module (torch.nn.Sequential) – sequential module to be partitioned
sample (torch.Tensor) – example input with arbitrary batch size
- Keyword Arguments
timeout (float) – profiling iterates again while the timeout (in seconds) has not been exceeded (default: 1.0)
device ('cpu' or 'cuda' device) – CPU or CUDA device where each layer is profiled (default: the current CUDA device)
- Returns
A list of the number of layers in each partition. Use it for the balance parameter of GPipe.
Note
module and sample must be placed on the same device.
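For instance, here is a minimal sketch of profiling on the CPU before handing the model to GPipe; passing a plain 'cpu' string for device is assumed to work, per the keyword documentation above:

```python
# Keep module and sample on the same device, here the CPU.
sample = torch.empty(128, 3, 224, 224)        # CPU tensor
balance = balance_by_time(
    torch.cuda.device_count(),
    model,                                    # the nn.Sequential, still on the CPU here
    sample,
    device='cpu',                             # profile each layer on the CPU
)
gpipe = GPipe(model, balance, chunks=8)       # GPipe then places each partition on its device
```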
- torchgpipe.balance.balance_by_size(partitions, module, input, chunks=1, param_scale=2.0, device=torch.device('cuda'))¶

Naive automatic balancing by CUDA memory usage per layer.
During training, the memory required for parameters depends on which optimizer is used. Optimizers may keep buffers for each parameter to track optimization statistics internally, such as the momentum buffer in SGD.

To get a more reliable size-based balance, you should specify param_scale with regard to your optimizer. The default param_scale is 2 instead of 1 because of the gradient accumulation that every optimizer requires.
Follow this guide to choose the correct param_scale for typical optimizers:

| Optimizer | param_scale | Internal State                             |
|-----------|-------------|--------------------------------------------|
| SGD       | 2–3         | (momentum_buffer)                          |
| Adam      | 4–5         | exp_avg, exp_avg_sq, (max_exp_avg_sq)      |
| Adadelta  | 4           | square_avg, acc_delta                      |
| Adagrad   | 3           | sum                                        |
| RMSprop   | 3–5         | square_avg, (momentum_buffer), (grad_avg)  |
Here’s a simple example with the Adam optimizer:
```python
balance = balance_by_size(
    torch.cuda.device_count(),
    model,

    # Same size as the mini-batch used for training
    torch.empty(1024, 3, 224, 224),

    # Number of micro-batches to train with GPipe
    chunks=8,

    # 4 for Adam
    param_scale=4.0,
)
gpipe = GPipe(model, balance, chunks=8)
adam = Adam(gpipe.parameters())
```
- Parameters
partitions (int) – intended number of partitions
module (torch.nn.Sequential) – sequential module to be partitioned
input (torch.Tensor) – example mini-batch with the same size as used for training
- Keyword Arguments
chunks (int) – number of micro-batches that will be used to train (default: 1)
param_scale (float) – how many copies of parameters will be allocated for training; it depends on the optimizer, see the guide above (default: 2.0)
device ('cuda' device) – CUDA device where each layer is profiled (default: the current CUDA device)
- Returns
A list of the number of layers in each partition. Use it for the balance parameter of GPipe.
Note
module and input must be placed on the same CUDA device.
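As a minimal sketch of this note, with the default profiling device (the current CUDA device) and the Adam-style param_scale carried over from the example above; everything else is an illustrative assumption:

```python
device = torch.device('cuda')               # the default profiling device
model = model.to(device)                    # module and input on the same CUDA device
sample = torch.empty(1024, 3, 224, 224, device=device)

balance = balance_by_size(
    torch.cuda.device_count(),
    model,
    sample,
    chunks=8,
    param_scale=4.0,                        # e.g. Adam, per the guide above
)
```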