API

GPipe Module

class torchgpipe.GPipe(module, balance, **kwargs)

Wraps an arbitrary Sequential module to train on GPipe. If the module requires lots of memory, GPipe will be very efficient:

from torch import nn
from torchgpipe import GPipe

model = nn.Sequential(a, b, c, d)
model = GPipe(model, balance=[1, 1, 1, 1], chunks=8)
output = model(input)

GPipe combines pipeline parallelism with checkpointing to reduce peak memory required to train while minimizing device under-utilization.

You should determine the balance when defining a GPipe module, as balancing will not be done automatically. The module will be partitioned across multiple devices according to the given balance. You may rely on heuristics, such as the Automatic Balancing helpers below, to find your own optimal configuration.
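For instance, an uneven balance puts more layers on the first device. A sketch, assuming two CUDA devices and the same placeholder layers a, b, c, d as above:

model = nn.Sequential(a, b, c, d)
# balance=[3, 1]: layers a, b, c form the first partition and layer d
# the second; partitions are placed on the devices in order.
model = GPipe(model, balance=[3, 1], chunks=8)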

Parameters:
  • module (nn.Sequential) – sequential module to be parallelized
  • balance (list of int) – list of the number of layers in each partition
Keyword Arguments:
 
  • devices (iterable of devices) – devices to use (default: all CUDA devices)
  • chunks (int) – number of micro-batches (default: 1)
  • checkpoint (str) – when to enable checkpointing, one of 'always', 'except_last', or 'never' (default: 'except_last'; see the example after this list)
  • deferred_batch_norm (bool) – whether to use deferred BatchNorm moving statistics (default: False; see Deferred BatchNorm for more details)
Raises:
  • TypeError – the module is not a Sequential.
  • ValueError – invalid arguments, or wrong balance
  • IndexError – the number of devices is fewer than the number of partitions.
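A sketch of how the keyword arguments combine; the explicit device list here is an assumption for illustration, not a requirement:

model = GPipe(model,
              balance=[2, 2],
              devices=['cuda:0', 'cuda:1'],  # assumed device names
              chunks=8,
              checkpoint='except_last')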
forward(input)

GPipe is a fairly transparent module wrapper. It doesn’t modify the input and output signature of the underlying module. But there is a type restriction: the input and output have to be a Tensor or a tuple of tensors. This restriction is applied at partition boundaries too, as the sketch below illustrates.

Parameters:input (tensor or tensors) – input mini-batch
Returns:output mini-batch
Return type:tensor or tensors
Raises:TypeError – input is not a tensor or tensors.
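
For example, a tuple of tensors may flow between layers and across a partition boundary. The modules here are hypothetical, written only to illustrate the restriction:

from torch import nn
from torchgpipe import GPipe

class Split(nn.Module):
    # hypothetical layer: takes a tensor, returns a tuple of tensors
    def forward(self, x):
        return x, x.clone()

class Merge(nn.Module):
    # hypothetical layer: takes the tuple produced by Split
    def forward(self, xs):
        a, b = xs
        return a + b

# with balance=[1, 1], the tuple crosses the partition boundary
model = GPipe(nn.Sequential(Split(), Merge()), balance=[1, 1], chunks=2)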
devices

The devices mapped to each partition.

devices[-1] refers to the device of the last partition, which means it is the output device. You will probably need it to transfer the target to the output device, so that the loss can be calculated without a device mismatch RuntimeError. For example:

import torch.nn.functional as F

out_device = gpipe.devices[-1]

for input, target in loader:
    target = target.to(out_device, non_blocking=True)
    output = gpipe(input)
    loss = F.cross_entropy(output, target)

Inspecting GPipe Timeline

torchgpipe.current_microbatch()

Gets the current micro-batch identifier as a tensor.

If your module’s behavior depends on which micro-batch lane it is currently running in, use this function to identify the lane.

Returns:A tensor which identifies the current micro-batch lane, or None if called outside of a GPipe context.
Return type:tensor or None
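
A minimal sketch; the pass-through layer and what it does with the identifier are assumptions:

import torchgpipe
from torch import nn

class LaneAware(nn.Module):
    # hypothetical layer whose behavior depends on the lane
    def forward(self, x):
        lane = torchgpipe.current_microbatch()
        if lane is not None:
            # running inside a GPipe pipeline; `lane` identifies
            # which micro-batch this invocation belongs to
            ...
        return x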
torchgpipe.is_recomputing()

Whether the current thread is under checkpoint recomputation.

Returns:True if it’s under checkpoint recomputation.
Return type:bool
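
For example, a layer with a side effect should skip it while the forward pass is being replayed for checkpointing, so that each micro-batch is counted only once. A sketch:

from torch import nn
from torchgpipe import is_recomputing

class Counter(nn.Module):
    def __init__(self):
        super().__init__()
        self.counter = 0

    def forward(self, input):
        # checkpointing replays forward() during backward;
        # count each micro-batch only once
        if not is_recomputing():
            self.counter += 1
        return input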

Automatic Balancing

torchgpipe_balancing.balance_by_time(module, sample, partitions, device, timeout)

Balances the given sequential module by elapsed time per layer.

Parameters:
  • module (nn.Sequential) – sequential module to be partitioned
  • sample (Tensor) – example input
Keyword Arguments:
 
  • partitions (int) – intended number of partitions (default: 1)
  • device (torch.device) – CUDA device where the module is profiled (default: any related CUDA device or torch.device('cuda'))
  • timeout (float) – profiling iterates again until this timeout (in seconds) is exceeded (default: 1.0)
Returns:

A list of the number of layers in each partition. Use it for the balance parameter of GPipe.
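
A sketch of typical usage; the nn.Sequential model and the sample shape are assumptions (pick a sample that matches your real input):

import torch
from torchgpipe import GPipe
from torchgpipe_balancing import balance_by_time

sample = torch.rand(128, 3, 224, 224)  # assumed input shape
balance = balance_by_time(model, sample, partitions=torch.cuda.device_count())
model = GPipe(model, balance, chunks=8)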

torchgpipe_balancing.balance_by_size(module, sample, partitions, device)

Balances the given sequential module by memory usage per layer.

Note

This function relies on torch.cuda.reset_max_memory_allocated(), which was introduced in PyTorch 1.1. Therefore, it supports neither CPU tensors nor PyTorch 1.0.x.

Parameters:
  • module (nn.Sequential) – sequential module to be partitioned
  • sample (Tensor) – example input
Keyword Arguments:
 
  • partitions (int) – intended number of partitions (default: 1)
  • device (torch.device) – CUDA device where the module is profiled (default: any related CUDA device or torch.device('cuda'))
Returns:

A list of the number of layers in each partition. Use it for the balance parameter of GPipe.
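
Usage mirrors balance_by_time; a sketch under the same assumptions:

import torch
from torchgpipe import GPipe
from torchgpipe_balancing import balance_by_size

sample = torch.rand(128, 3, 224, 224)  # assumed input shape
balance = balance_by_size(model, sample, partitions=torch.cuda.device_count())
model = GPipe(model, balance, chunks=8)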