from torchgpipe import GPipe model = nn.Sequential(a, b, c, d) model = GPipe(model, balance=[1, 1, 1, 1], chunks=8) for input in data_loader: output = model(input)
What is GPipe?¶
GPipe is a scalable pipeline parallelism library published by Google Brain, which allows efficient training of large, memory-consuming models. According to the paper, GPipe can train a 25x larger model by using 8x devices (TPU), and train a model 3.5x faster by using 4x devices.
Google trained AmoebaNet-B with 557M parameters over GPipe. This model has achieved 84.3% top-1 and 97.0% top-5 accuracy on ImageNet classification benchmark (the state-of-the-art performance as of May 2019).
- Understanding GPipe
- User Guide