GPipe Module

class torchgpipe.GPipe(module, balance, **kwargs)

Wraps an arbitrary nn.Sequential module to train it with GPipe. If the module requires a lot of memory, GPipe can be very efficient.

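from torch import nn
from torchgpipe import GPipe

model = nn.Sequential(a, b, c, d)  # a, b, c, d: any nn.Module layers
model = GPipe(model, balance=[1, 1, 1, 1], chunks=8)  # one layer per partition, 8 micro-batches
output = model(input)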

GPipe combines pipeline parallelism with checkpointing to reduce the peak memory required for training while minimizing device under-utilization.
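
Splitting a mini-batch into micro-batches is what lets the partitions work concurrently. The sketch below is a minimal illustration, assuming a GPipe-style forward schedule in which micro-batch i runs on partition j at clock tick i + j; clock_cycles is a hypothetical helper written for this page, not part of torchgpipe's public API.

```python
from typing import List, Tuple


def clock_cycles(chunks: int, partitions: int) -> List[List[Tuple[int, int]]]:
    """For each clock tick, list the (micro-batch, partition) pairs that run."""
    schedule = []
    for tick in range(chunks + partitions - 1):
        # Micro-batch i is on partition j = tick - i, so 0 <= i < chunks
        # and 0 <= tick - i < partitions bound which micro-batches are active.
        first = max(0, tick - partitions + 1)
        last = min(tick, chunks - 1)
        schedule.append([(i, tick - i) for i in range(first, last + 1)])
    return schedule


for tick, work in enumerate(clock_cycles(chunks=3, partitions=2)):
    print(tick, work)
# 0 [(0, 0)]
# 1 [(0, 1), (1, 0)]
# 2 [(1, 1), (2, 0)]
# 3 [(2, 1)]
```

With chunks=1 every tick keeps only one partition busy. With chunks=8 and 4 partitions the schedule spans 11 ticks and each partition is busy for 8 of them, which is the under-utilization reduction the paragraph above refers to.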

You should determine the balance when defining a GPipe module, as balancing will not be done automatically. The module will be partitioned across multiple devices according to the given balance. You may rely on heuristics to find your own optimal configuration.
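
The way a balance maps layers to partitions can be sketched in plain Python. This assumes, as described above, that partitions are contiguous runs of layers taken in order and that partition k is placed on the k-th device; partition_layers is a hypothetical helper, and raising ValueError on a mismatched balance is this sketch's choice, not necessarily torchgpipe's exact error type.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def partition_layers(layers: Sequence[T], balance: Sequence[int]) -> List[List[T]]:
    """Group consecutive layers: partition k gets the next balance[k] layers."""
    if any(n < 1 for n in balance):
        raise ValueError("every partition needs at least one layer")
    if sum(balance) != len(layers):
        raise ValueError("sum(balance) must equal the number of layers")
    partitions, start = [], 0
    for n in balance:
        partitions.append(list(layers[start:start + n]))
        start += n
    return partitions


# With balance=[1, 1, 2], partition 0 would live on devices[0], and so on.
print(partition_layers(["a", "b", "c", "d"], [1, 1, 2]))
# [['a'], ['b'], ['c', 'd']]
```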

Parameters
  • module (torch.nn.Sequential) – sequential module to be parallelized

  • balance (ints) – list of the number of layers in each partition

Keyword Arguments
  • devices (iterable of devices) – devices to use (default: all CUDA devices)

  • chunks (int) – number of micro-batches (default: 1)

  • checkpoint (str) – when to enable checkpointing, one of 'always', 'except_last', or 'never' (default: 'except_last')

  • deferred_batch_norm (bool) – whether to use deferred BatchNorm moving statistics (default: False, see Deferred Batch Normalization for more details)
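
The chunks and checkpoint options act per micro-batch. A minimal sketch, assuming the mini-batch is split along its first dimension the way torch.chunk splits a tensor, and assuming 'except_last' means every micro-batch but the final one is checkpointed, as its name suggests; split_micro_batches and checkpointed are hypothetical helpers operating on plain lists.

```python
import math
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_micro_batches(batch: Sequence[T], chunks: int) -> List[List[T]]:
    """Split like torch.chunk: ceil(len / chunks) items per micro-batch, so the
    last one may be smaller and fewer than `chunks` micro-batches may result."""
    if chunks < 1:
        raise ValueError("chunks must be at least 1")
    if not batch:
        return []
    size = math.ceil(len(batch) / chunks)
    return [list(batch[i:i + size]) for i in range(0, len(batch), size)]


def checkpointed(mode: str, chunks: int) -> List[bool]:
    """Which micro-batches get checkpointing under each mode (assumed semantics)."""
    stop = {"always": chunks, "except_last": chunks - 1, "never": 0}[mode]
    return [i < stop for i in range(chunks)]


print([len(mb) for mb in split_micro_batches(list(range(10)), 4)])  # [3, 3, 3, 1]
print(checkpointed("except_last", 4))  # [True, True, True, False]
```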


forward(input)

GPipe is a fairly transparent module wrapper. It doesn’t modify the input and output signature of the underlying module. However, there is a type restriction: input and output have to be a Tensor or a tuple of tensors. This restriction applies at partition boundaries too.

Parameters

input (torch.Tensor or tensors) – input mini-batch

Returns

output mini-batch

Return type

tensor or tensors

Raises

TypeError – input is not a tensor or tensors.


balance

The number of layers in each partition.


devices

The devices mapped to each partition.

devices[-1] refers to the device of the last partition, which means it is the output device. You will probably need it to move the target to the output device, so that computing the loss does not raise a device-mismatch RuntimeError. For example:

import torch.nn.functional as F

out_device = gpipe.devices[-1]

for input, target in loader:
    target = target.to(out_device, non_blocking=True)  # same device as the output
    output = gpipe(input)
    loss = F.cross_entropy(output, target)

chunks

The number of micro-batches.


checkpoint

The checkpoint mode that determines when to enable checkpointing. It is one of 'always', 'except_last', or 'never'.

Skip Connections

@torchgpipe.skip.skippable([stash][, pop])

The decorator to define a nn.Module with skip connections. Decorated modules are called “skippable”. This functionality also works when the module is not wrapped by GPipe.

Each skip tensor is managed by its name. Before manipulating skip tensors, a skippable module must statically declare their names through the stash and/or pop parameters. Skip tensors with pre-declared names can be stashed by yield stash(name, tensor) or popped by tensor = yield pop(name).

Here is an example with three layers. A skip tensor named “1to3” is stashed and popped at the first and last layer, respectively:

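from torch import nn
from torchgpipe.skip import pop, skippable, stash

# f1, f2 and f3 stand for arbitrary computations.

@skippable(stash=['1to3'])
class Layer1(nn.Module):
    def forward(self, input):
        yield stash('1to3', input)
        return f1(input)

class Layer2(nn.Module):
    def forward(self, input):
        return f2(input)

@skippable(pop=['1to3'])
class Layer3(nn.Module):
    def forward(self, input):
        skip_1to3 = yield pop('1to3')
        return f3(input) + skip_1to3

model = nn.Sequential(Layer1(), Layer2(), Layer3())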
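
The yield-based protocol can be mimicked in plain Python to show how a caller might drive such a forward. This is a toy, assuming hypothetical Stash and Pop command objects and a shared dict as skip storage; it is not how torchgpipe implements skippable internally.

```python
from typing import Any, Dict, Generator


class Stash:
    """Hypothetical command: store `tensor` under `name`."""
    def __init__(self, name: str, tensor: Any) -> None:
        self.name = name
        self.tensor = tensor


class Pop:
    """Hypothetical command: take the value stored under `name`."""
    def __init__(self, name: str) -> None:
        self.name = name


def run_forward(gen: Generator[Any, Any, Any], storage: Dict[str, Any]) -> Any:
    """Drive a generator-style forward and return the value of its `return`."""
    try:
        command = next(gen)
        while True:
            if isinstance(command, Stash):
                storage[command.name] = command.tensor
                command = gen.send(None)  # `yield stash(...)` evaluates to None
            elif isinstance(command, Pop):
                command = gen.send(storage.pop(command.name))  # hand the value back
            else:
                raise TypeError(f"unknown command: {command!r}")
    except StopIteration as stop:
        return stop.value


# Plain-number stand-ins for Layer1, Layer2 and Layer3.
def layer1(x):
    yield Stash("1to3", x)
    return x + 1


def layer2(x):
    return x * 2


def layer3(x):
    skip = yield Pop("1to3")
    return x + skip


storage: Dict[str, Any] = {}
h = run_forward(layer1(5), storage)    # stashes 5, returns 6
h = layer2(h)                          # 12
out = run_forward(layer3(h), storage)  # pops 5, returns 12 + 5 = 17
```

Layers without skip connections, like layer2 here, are ordinary functions and are called directly; only generator-style forwards need the driver.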

One skippable module can stash or pop multiple skip tensors:

@skippable(stash=['alice', 'bob'], pop=['carol'])
class StashStashPop(nn.Module):
    def forward(self, input):
        yield stash('alice', f_alice(input))
        yield stash('bob', f_bob(input))
        carol = yield pop('carol')
        return input + carol

Every skip tensor must be associated with exactly one pair of stash and pop. GPipe checks this restriction automatically when wrapping a module. You can also check the restriction by verify_skippables() without GPipe.


@skippable changes the type of the wrapped class. However, as of mypy v0.740, mypy does not yet understand class decorators (#3135).

There are two workarounds:

  1. Ignore the type errors with # type: ignore.

  2. Use skippable()() as a function instead of a decorator.

Skippable.isolate(ns[, only=names])

Isolates a specified subset or the whole set of skip tensors into a namespace. In a single sequential module, skip tensors with the same name are not allowed unless they are isolated by different namespaces.

Here’s an example using the same name for skip tensors twice. Each pair of Layer1 and Layer3 is isolated with its own namespace, ns1 or ns2, so there is no conflict anymore:

ns1 = Namespace()
ns2 = Namespace()

model = nn.Sequential(
    Layer1().isolate(ns1),
    Layer1().isolate(ns2),
    Layer2(),
    Layer3().isolate(ns2),
    Layer3().isolate(ns1),
)
When the only parameter is omitted, all skip tensors are isolated. You can isolate a subset of skip tensors by passing the only parameter:

ns_alice = Namespace()
ns_bob = Namespace()

model = nn.Sequential(
    StashStashPop().isolate(ns_alice, only=['alice']) \
                   .isolate(ns_bob, only=['bob']),
    # ... other layers ...
)

Parameters

ns (Namespace) – namespace for isolation

Keyword Arguments

only (iterable of strs) – names of specific skip tensors to be isolated (omit this option to isolate all skip tensors declared in this module)

Returns

this module itself

torchgpipe.skip.stash(name, tensor)

The command to stash a skip tensor.

def forward(self, input):
    yield stash('name', input)
    return f(input)

Parameters
  • name (str) – name of skip tensor

  • tensor (torch.Tensor or None) – tensor to pass to the skip connection


torchgpipe.skip.pop(name)

The command to pop a skip tensor.

def forward(self, input):
    skip = yield pop('name')
    return f(input) + skip

Parameters

name (str) – name of skip tensor

Returns

the skip tensor previously stashed by another layer under the same name

class torchgpipe.skip.Namespace

Namespace for isolating skip tensors used by isolate().


torchgpipe.skip.verify_skippables(module)

Verifies the integrity of the underlying skippable modules.

Every skip tensor must have exactly one pair of stash and pop. If there are one or more unmatched pairs, it raises TypeError with detailed messages.

Here are a few failure cases. verify_skippables() will report failure for these cases:

# Layer1 stashes "1to3".
# Layer3 pops "1to3".

nn.Sequential(Layer1(), Layer2())
#               └──── ?

nn.Sequential(Layer2(), Layer3())
#                   ? ────┘

nn.Sequential(Layer1(), Layer2(), Layer3(), Layer3())
#               └───────────────────┘       ^^^^^^

nn.Sequential(Layer1(), Layer1(), Layer2(), Layer3())
#             ^^^^^^      └───────────────────┘

To use the same name for multiple skip tensors, they must be isolated by different namespaces. See isolate().
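
The pairing rule can be checked with a single left-to-right scan over the layers. A minimal sketch, assuming each layer is described only by the (namespace, name) keys it declares to stash and pop, with None standing for the default namespace; find_unmatched is a hypothetical helper that returns messages, whereas verify_skippables() raises TypeError.

```python
from typing import Hashable, List, Optional, Sequence, Set, Tuple

# A skip tensor is identified by (namespace, name); None means "no namespace".
Key = Tuple[Optional[Hashable], str]
# Each layer is described by the keys it stashes and the keys it pops.
Layer = Tuple[Sequence[Key], Sequence[Key]]


def find_unmatched(layers: Sequence[Layer]) -> List[str]:
    """Return one message per violation; an empty list means every pair matches."""
    errors: List[str] = []
    stashed: Set[Key] = set()
    popped: Set[Key] = set()
    for index, (stashes, pops) in enumerate(layers):
        for key in stashes:
            if key in stashed:
                errors.append(f"layer {index} stashes {key} a second time")
            stashed.add(key)
        for key in pops:
            if key in popped:
                errors.append(f"layer {index} pops {key} a second time")
            elif key not in stashed:
                errors.append(f"layer {index} pops {key} before any stash")
            popped.add(key)
    for key in sorted(stashed - popped, key=repr):
        errors.append(f"{key} is stashed but never popped")
    return errors


def stasher(ns=None) -> Layer:  # like Layer1: stashes "1to3"
    return ([(ns, "1to3")], [])


def passthrough() -> Layer:  # like Layer2: no skip tensors
    return ([], [])


def popper(ns=None) -> Layer:  # like Layer3: pops "1to3"
    return ([], [(ns, "1to3")])
```

The four failure cases above each produce one message here, while the namespaced model from isolate() produces none because (ns1, "1to3") and (ns2, "1to3") are distinct keys.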


Raises

TypeError – one or more pairs of stash and pop are not matched.

Inspecting GPipe Timeline


torchgpipe.is_checkpointing()

Whether the current forward propagation is under checkpointing.

Returns

True if it’s under checkpointing.

Return type

bool

torchgpipe.is_recomputing()

Whether the current forward propagation is under checkpoint recomputation. Use this to prevent duplicated side effects during forward propagation:

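from torch import nn
from torchgpipe import is_recomputing

class Counter(nn.Module):
    def __init__(self):
        super().__init__()
        self.counter = 0

    def forward(self, input):
        # Skip the side effect when checkpointing replays this forward.
        if not is_recomputing():
            self.counter += 1
        return input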

Returns

True if it’s under checkpoint recomputation.

Return type

bool
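
For checkpointed micro-batches, a layer's forward runs twice: once during the forward pass and again when the backward pass recomputes activations. A plain-Python toy of the guard shown in Counter, assuming a thread-local flag that is set only during the replay; is_recomputing_toy and recomputing_toy are hypothetical stand-ins, not torchgpipe functions.

```python
import threading
from contextlib import contextmanager
from typing import Iterator

_state = threading.local()


def is_recomputing_toy() -> bool:
    """Hypothetical stand-in for is_recomputing(): read a thread-local flag."""
    return getattr(_state, "recomputing", False)


@contextmanager
def recomputing_toy() -> Iterator[None]:
    """Set the flag while a checkpointed forward is being replayed."""
    _state.recomputing = True
    try:
        yield
    finally:
        _state.recomputing = False


class Counter:
    """Plain-Python version of the Counter module above."""
    def __init__(self) -> None:
        self.counter = 0

    def forward(self, x):
        if not is_recomputing_toy():
            self.counter += 1
        return x


layer = Counter()
layer.forward(1)         # first run, during the forward pass
with recomputing_toy():  # replay, during the backward pass
    layer.forward(1)
print(layer.counter)     # 1
```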


Automatic Balancing

torchgpipe.balance.balance_by_time(partitions, module, sample, timeout=1.0, device=torch.device('cuda'))

Naive automatic balancing by elapsed time per layer.

import torch
from torchgpipe.balance import balance_by_time

sample = torch.empty(128, 3, 224, 224)
balance = balance_by_time(torch.cuda.device_count(), model, sample)
gpipe = GPipe(model, balance, chunks=8)

Parameters
  • partitions (int) – intended number of partitions

  • module (torch.nn.Sequential) – sequential module to be partitioned

  • sample (torch.Tensor) – example input with arbitrary batch size

Keyword Arguments
  • timeout (float) – profiling iterates again if the timeout (in seconds) is not exceeded (default: 1.0)

  • device ('cpu' or 'cuda' device) – CPU or CUDA device where each layer is profiled (default: the current CUDA device)


Returns

A list of the number of layers in each partition. Use it for the balance parameter of GPipe.


module and sample must be placed on the same device.
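
Once per-layer times are measured, they still have to be turned into a balance. A minimal sketch of that second step, assuming the costs are already known: a dynamic program that splits the layers into contiguous partitions so that the most expensive partition is as cheap as possible. This is one standard way to do it, not necessarily the algorithm balance_by_time uses, and balance_from_costs is a hypothetical name.

```python
from itertools import accumulate
from typing import List, Sequence


def balance_from_costs(costs: Sequence[float], partitions: int) -> List[int]:
    """Split layers into `partitions` contiguous, non-empty groups minimizing the
    largest group cost. Returns a balance list (number of layers per partition)."""
    n = len(costs)
    if not 1 <= partitions <= n:
        raise ValueError("need 1 <= partitions <= number of layers")
    prefix = [0.0, *accumulate(costs)]  # prefix[i] = total cost of layers 0..i-1
    inf = float("inf")
    # best[k][i]: smallest possible max cost when the first i layers form k partitions
    best = [[inf] * (n + 1) for _ in range(partitions + 1)]
    cut = [[0] * (n + 1) for _ in range(partitions + 1)]
    best[0][0] = 0.0
    for k in range(1, partitions + 1):
        for i in range(k, n + 1):
            for j in range(k - 1, i):  # the k-th partition holds layers j..i-1
                cost = max(best[k - 1][j], prefix[i] - prefix[j])
                if cost < best[k][i]:
                    best[k][i] = cost
                    cut[k][i] = j
    # Walk the recorded cuts backwards to recover partition sizes.
    sizes: List[int] = []
    i = n
    for k in range(partitions, 0, -1):
        j = cut[k][i]
        sizes.append(i - j)
        i = j
    return sizes[::-1]


print(balance_from_costs([4, 1, 1, 1, 1], 2))  # [1, 4]
```

The triple loop costs O(partitions × layers²) time, which stays small for typical layer counts.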

torchgpipe.balance.balance_by_size(partitions, module, input, chunks=1, param_scale=2.0, device=torch.device('cuda'))

Naive automatic balancing by CUDA memory usage per layer.

During training, the memory required for parameters depends on which optimizer is used. Optimizers may use buffers for each parameter to track optimization statistics internally, such as the momentum buffer in SGD.

To get a more reliable size-based balance, you should specify param_scale according to your optimizer. The default param_scale is 2 instead of 1 because every optimizer needs gradients, which are accumulated in a buffer of the same size as each parameter.

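Follow this guide to choose the right param_scale for typical optimizers:

Optimizer   param_scale   Internal State
SGD         2–3           (momentum_buffer)
Adam        4–5           exp_avg, exp_avg_sq, (max_exp_avg_sq)
Adadelta    4             square_avg, acc_delta
Adagrad     3             sum
RMSprop     3–5           square_avg, (momentum_buffer), (grad_avg)

Buffers in parentheses are allocated only when the corresponding option is enabled (momentum for SGD and RMSprop, amsgrad for Adam, centered for RMSprop).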
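
The guide above follows a simple rule: one copy for the parameter, one for its gradient, plus one per optimizer state buffer. A minimal sketch of that arithmetic, assuming every listed buffer matches its parameter's size; estimate_param_scale is a hypothetical helper, distinct from the param_scale keyword argument.

```python
from typing import Iterable, Sequence


def estimate_param_scale(buffers: Sequence[str], enabled: Iterable[str] = ()) -> float:
    """2 (parameter + gradient) plus one per state buffer.

    Names in parentheses, e.g. "(momentum_buffer)", are optional buffers that
    only count when their bare name is listed in `enabled`.
    """
    enabled = set(enabled)
    count = 0
    for name in buffers:
        if name.startswith("(") and name.endswith(")"):
            if name[1:-1] in enabled:
                count += 1
        else:
            count += 1
    return 2.0 + count


adam = ["exp_avg", "exp_avg_sq", "(max_exp_avg_sq)"]
print(estimate_param_scale(adam))                      # 4.0
print(estimate_param_scale(adam, {"max_exp_avg_sq"}))  # 5.0, i.e. amsgrad=True
```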

Here’s a simple example with the Adam optimizer:

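import torch
from torch.optim import Adam
from torchgpipe import GPipe
from torchgpipe.balance import balance_by_size

balance = balance_by_size(
    torch.cuda.device_count(),
    model,

    # Same size as the mini-batch to train with
    torch.empty(1024, 3, 224, 224),

    # Number of micro-batches to train with GPipe
    chunks=8,

    # 4 for Adam
    param_scale=4.0,
)

gpipe = GPipe(model, balance, chunks=8)
adam = Adam(gpipe.parameters())

Parameters
  • partitions (int) – intended number of partitions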

  • module (torch.nn.Sequential) – sequential module to be partitioned

  • input (torch.Tensor) – example mini-batch with the same size to train

Keyword Arguments
  • chunks (int) – number of micro-batches that will be used for training (default: 1)

  • param_scale (float) – how many copies of the parameters would be allocated for training. It depends on the optimizer; see the guide above. (default: 2.0)

  • device ('cuda' device) – CUDA device where each layer is profiled (default: the current CUDA device)


Returns

A list of the number of layers in each partition. Use it for the balance parameter of GPipe.


module and input must be placed on the same CUDA device.