API¶
GPipe Module¶
class torchgpipe.GPipe(module, balance, **kwargs)¶

Wraps an arbitrary Sequential module to train on GPipe. If the module requires lots of memory, GPipe will be very efficient:

    model = nn.Sequential(a, b, c, d)
    model = GPipe(model, balance=[1, 1, 1, 1], chunks=8)
    output = model(input)
GPipe combines pipeline parallelism with checkpointing to reduce peak memory required to train while minimizing device under-utilization.
You should determine the balance when defining a GPipe module, as balancing will not be done automatically. The module will be partitioned into multiple devices according to the given balance. You may rely on the heuristics in the Automatic Balancing section below to find an optimal configuration.
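The effect of the balance argument can be pictured with a small pure-Python sketch (an illustration only, not the torchgpipe partitioning code): each entry consumes that many consecutive layers, so the entries must sum to the total number of layers.

```python
def partition_by_balance(layers, balance):
    """Split `layers` into consecutive groups whose sizes follow `balance`."""
    if sum(balance) != len(layers):
        raise ValueError('balance must sum to the number of layers')
    partitions, start = [], 0
    for size in balance:
        partitions.append(layers[start:start + size])
        start += size
    return partitions

# Four layers split across two devices, 1 and 3 layers each:
print(partition_by_balance(['a', 'b', 'c', 'd'], [1, 3]))
# [['a'], ['b', 'c', 'd']]
```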
Parameters:
- module (nn.Sequential) – sequential module to be parallelized
- balance (ints) – list of number of layers in each partition

Keyword Arguments:
- devices (iterable of devices) – devices to use (default: all CUDA devices)
- chunks (int) – number of micro-batches (default: 1)
- checkpoint (str) – when to enable checkpointing, one of 'always', 'except_last', or 'never' (default: 'except_last')
- deferred_batch_norm (bool) – whether to use deferred BatchNorm moving statistics (default: False; see Deferred BatchNorm for more details)

Raises:
- TypeError – the module is not a Sequential
- ValueError – invalid arguments, or wrong balance
- IndexError – the number of devices is fewer than the number of partitions
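The three checkpoint modes can be summarized with a small helper. This is an illustrative sketch of the assumed semantics, not torchgpipe internals: 'always' checkpoints every micro-batch, 'never' checkpoints none, and 'except_last' checkpoints all but the final micro-batch.

```python
def should_checkpoint(mode, microbatch_index, num_microbatches):
    """Decide whether a micro-batch runs under checkpointing (assumed semantics)."""
    if mode == 'always':
        return True
    if mode == 'never':
        return False
    if mode == 'except_last':
        # Checkpoint every micro-batch except the final one.
        return microbatch_index < num_microbatches - 1
    raise ValueError("checkpoint must be 'always', 'except_last' or 'never'")

print([should_checkpoint('except_last', i, 4) for i in range(4)])
# [True, True, True, False]
```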
forward(input)¶

GPipe is a fairly transparent module wrapper. It does not modify the input and output signatures of the underlying module, but there is a type restriction: the input and output must each be a Tensor or a tuple of tensors. This restriction also applies at partition boundaries.

Parameters: input (tensor or tensors) – input mini-batch
Returns: output mini-batch
Return type: tensor or tensors
Raises: TypeError – input is not a tensor or tensors
devices¶

The devices mapped to each partition. devices[-1] refers to the device of the last partition, which is also the output device. You will probably need it to move the target to the output device, so that the loss can be computed without a device-mismatch RuntimeError. For example:

    out_device = gpipe.devices[-1]

    for input, target in loader:
        target = target.to(out_device, non_blocking=True)
        output = gpipe(input)
        loss = F.cross_entropy(output, target)
Inspecting GPipe Timeline¶
torchgpipe.current_microbatch()¶

Gets the current micro-batch identifier as a tensor. If your module's behavior depends on which micro-batch lane it is processing, use this function to identify the lane.

Returns: a tensor identifying the current micro-batch lane, or None outside of a GPipe context
Return type: tensor or None
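The lane-tracking pattern behind such an API can be sketched in plain Python with a context variable. This is an illustration of the general pattern only, not the torchgpipe implementation; the names current_lane and run_in_lane are hypothetical.

```python
import contextvars

# Hypothetical lane tracker; None means "outside any pipeline run".
_current_lane = contextvars.ContextVar('current_lane', default=None)

def current_lane():
    """Return the current micro-batch lane id, or None outside a pipeline run."""
    return _current_lane.get()

def run_in_lane(lane, fn, *args):
    """Run `fn` with the lane id set, restoring the previous value afterwards."""
    token = _current_lane.set(lane)
    try:
        return fn(*args)
    finally:
        _current_lane.reset(token)

print(current_lane())                # None: outside any lane
print(run_in_lane(2, current_lane))  # 2: inside lane 2
```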
torchgpipe.is_recomputing()¶

Returns whether the current thread is under checkpoint recomputation.

Returns: True if the current thread is under checkpoint recomputation
Return type: bool
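A typical use of such a flag is to suppress side effects (logging, statistics updates) during the recomputed forward pass, so they happen only once per micro-batch. A minimal sketch of the pattern with a thread-local flag follows; it is not the torchgpipe implementation, and the recompute helper is hypothetical.

```python
import threading

_local = threading.local()

def is_recomputing():
    """Whether the current thread is inside a (simulated) recomputation."""
    return getattr(_local, 'recomputing', False)

def recompute(fn, *args):
    """Re-run `fn` with the recomputation flag set."""
    _local.recomputing = True
    try:
        return fn(*args)
    finally:
        _local.recomputing = False

calls = []

def forward(x):
    if not is_recomputing():
        calls.append(x)  # side effect only on the first pass
    return x * 2

forward(1)             # first pass: records the call
recompute(forward, 1)  # recomputation: side effect skipped
print(calls)           # [1]
```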
Automatic Balancing¶
torchgpipe_balancing.balance_by_time(module, sample, partitions, device, timeout)¶

Balances the given sequential module by elapsed time per layer.

Parameters:
- module (nn.Sequential) – sequential module to be partitioned
- sample (Tensor) – example input

Keyword Arguments:
- partitions (int) – intended number of partitions (default: 1)
- device (torch.device) – CUDA device where the module is profiled (default: any related CUDA device, or torch.device('cuda'))
- timeout (float) – profiling iterates again while the timeout (in seconds) is not exceeded (default: 1.0)

Returns: A list of number of layers in each partition. Use it for the balance parameter of GPipe.
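The idea behind cost-based balancing can be sketched without PyTorch. The following is an illustrative greedy heuristic, not the torchgpipe_balancing algorithm: given a measured cost per layer (time here, memory for balance_by_size), assign consecutive layers to partitions so each partition approaches an equal share of the total cost.

```python
def balance_by_cost(costs, partitions):
    """Greedily split consecutive per-layer costs into `partitions` groups."""
    if partitions < 1 or partitions > len(costs):
        raise ValueError('1 <= partitions <= number of layers required')
    target = sum(costs) / partitions  # ideal cost share per partition
    balance, size, used = [], 0, 0.0
    for i, cost in enumerate(costs):
        size += 1
        used += cost
        remaining_layers = len(costs) - i - 1
        remaining_parts = partitions - len(balance) - 1
        # Close this partition once it has its share of the total cost,
        # or when every remaining layer is needed for a later partition.
        if remaining_parts > 0 and (used >= target or remaining_layers == remaining_parts):
            balance.append(size)
            size, used = 0, 0.0
    balance.append(size)
    return balance

# Per-layer times in milliseconds, split across two devices:
print(balance_by_cost([1.0, 1.0, 2.0, 2.0], 2))  # [3, 1]
```

The result can be passed directly as the balance parameter of GPipe, since it is a list of per-partition layer counts summing to the number of layers.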
torchgpipe_balancing.balance_by_size(module, sample, partitions, device)¶

Balances the given sequential module by memory usage per layer.

Note: This function relies on torch.cuda.reset_max_memory_allocated(), which was introduced in PyTorch 1.1. Therefore, it supports neither CPU tensors nor PyTorch 1.0.x.

Parameters:
- module (nn.Sequential) – sequential module to be partitioned
- sample (Tensor) – example input

Keyword Arguments:
- partitions (int) – intended number of partitions (default: 1)
- device (torch.device) – CUDA device where the module is profiled (default: any related CUDA device, or torch.device('cuda'))

Returns: A list of number of layers in each partition. Use it for the balance parameter of GPipe.