This tutorial is available as a Jupyter notebook.


🔌 Training with TPUs

Composer provides beta support for single-core training on TPUs through integration with the torch_xla backend. For installation instructions and more details, see the torch_xla documentation.

Tutorial Goals and Concepts Covered

The goal of this tutorial is to show you the steps needed to train with Composer on TPUs. Concretely, we'll train a ResNet-20 on CIFAR-10 using a single TPU core.

The training setup is the same as on any other device, with two exceptions: the model must be moved to the TPU device before it is passed to the Trainer, and the Trainer must be given device='tpu'. We'll touch on both steps below.

Let's get started!


As prerequisites, first install torch_xla and the latest Composer version.

[ ]:
%pip install mosaicml
# To install from source instead of the latest release, comment out the command above and uncomment the following one.
# %pip install git+https://github.com/mosaicml/composer.git

%pip install cloud-tpu-client==0.10 torch==1.12.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl

from composer import Trainer
from composer.models import ComposerClassifier
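
Before building the model, it can help to confirm that the install step above actually made torch_xla importable. The following is a minimal sketch using only the standard library. `pick_trainer_device` is a hypothetical helper, not part of Composer, and the fallback to 'cpu' is an assumption about what you might want when no XLA runtime is present.

```python
import importlib.util


def pick_trainer_device(preferred: str = "tpu") -> str:
    """Return a device string for the Trainer, falling back to CPU.

    Hypothetical helper (not part of Composer): it only checks that the
    torch_xla package can be found, not that a TPU is actually attached.
    """
    if preferred == "tpu" and importlib.util.find_spec("torch_xla") is None:
        return "cpu"
    return preferred


device = pick_trainer_device()
print(f"Trainer device: {device}")
```

The rest of this tutorial assumes the result is 'tpu'. If the helper reports 'cpu', the torch_xla wheel most likely did not install for the current Python version.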



Next, we define the model and optimizer. TPUs require the model to be moved to the device before the optimizer is created, which we do here.

[ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.core.xla_model as xm

class Block(nn.Module):
    """A ResNet block."""

    def __init__(self, f_in: int, f_out: int, downsample: bool = False):
        super(Block, self).__init__()

        stride = 2 if downsample else 1
        self.conv1 = nn.Conv2d(f_in, f_out, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(f_out)
        self.conv2 = nn.Conv2d(f_out, f_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(f_out)
        self.relu = nn.ReLU(inplace=True)

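        # The shortcut is the identity unless the shape changes, in which case
        # a strided 1x1 convolution projects the input to the new shape.
        if downsample or f_in != f_out:
            self.shortcut = nn.Sequential(
                nn.Conv2d(f_in, f_out, kernel_size=1, stride=2, bias=False),
            )
        else:
            self.shortcut = nn.Sequential()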

    def forward(self, x: torch.Tensor):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return self.relu(out)

class ResNetCIFAR(nn.Module):
    """A residual neural network as originally designed for CIFAR-10."""

    def __init__(self, outputs: int = 10):
        super(ResNetCIFAR, self).__init__()

        depth = 20
        width = 16
        num_blocks = (depth - 2) // 6

        plan = [(width, num_blocks), (2 * width, num_blocks), (4 * width, num_blocks)]

        self.num_classes = outputs

        # Initial convolution.
        current_filters = plan[0][0]
        self.conv = nn.Conv2d(3, current_filters, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(current_filters)
        self.relu = nn.ReLU(inplace=True)

        # The subsequent blocks of the ResNet.
        blocks = []
        for segment_index, (filters, num_blocks) in enumerate(plan):
            for block_index in range(num_blocks):
                downsample = segment_index > 0 and block_index == 0
                blocks.append(Block(current_filters, filters, downsample))
                current_filters = filters

        self.blocks = nn.Sequential(*blocks)

        # Final fc layer. Size = number of filters in last segment.
        self.fc = nn.Linear(plan[-1][0], outputs)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor):
        out = self.relu(self.bn(self.conv(x)))
        out = self.blocks(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

model = ComposerClassifier(module=ResNetCIFAR(), num_classes=10)

model = model.to(xm.xla_device())

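# Create the optimizer only after the model is on the XLA device.
# The learning rate and momentum below are illustrative assumptions.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.02,
    momentum=0.9,
)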
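
To see why the model above is a ResNet-20 and why the final layer takes 64 inputs, the plan arithmetic can be traced by hand. The sketch below is plain Python and does not touch PyTorch. `conv_out_size` is a hypothetical helper that applies the standard convolution output-size formula, and the 32x32 input size is the CIFAR-10 image size.

```python
def conv_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution (standard floor formula)."""
    return (size + 2 * padding - kernel) // stride + 1


depth, width = 20, 16
num_blocks = (depth - 2) // 6  # 3 blocks per segment when depth is 20
plan = [(width, num_blocks), (2 * width, num_blocks), (4 * width, num_blocks)]

size = 32  # CIFAR-10 images are 32x32
size = conv_out_size(size, kernel=3, stride=1, padding=1)  # initial conv keeps 32x32

shapes = []
for segment_index, (filters, blocks_in_segment) in enumerate(plan):
    for block_index in range(blocks_in_segment):
        # Only the first block of the second and third segments downsamples.
        stride = 2 if (segment_index > 0 and block_index == 0) else 1
        # The second conv in each block has stride 1 and padding 1, so it
        # keeps the size; applying the first conv's formula is enough.
        size = conv_out_size(size, kernel=3, stride=stride, padding=1)
    shapes.append((filters, size, size))

# Two 3x3 convs per block, plus the initial conv and the final linear layer.
weight_layers = 2 * sum(n for _, n in plan) + 2

print(shapes)         # (channels, height, width) after each segment
print(weight_layers)  # layer count that gives the network its name
```

The last segment produces 64 channels at 8x8, so average pooling over the full 8x8 map leaves 64 features, which matches `nn.Linear(plan[-1][0], outputs)`.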


Creating the CIFAR-10 datasets and dataloaders is exactly the same as on non-TPU devices.

[ ]:
from torchvision import datasets, transforms

data_directory = "./data"

# Normalization constants
mean = (0.507, 0.487, 0.441)
std = (0.267, 0.256, 0.276)

batch_size = 1024

cifar10_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

train_dataset = datasets.CIFAR10(data_directory, train=True, download=True, transform=cifar10_transforms)
test_dataset = datasets.CIFAR10(data_directory, train=False, download=True, transform=cifar10_transforms)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
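
One detail worth checking with a batch size of 1024 is the size of the final batch in each epoch. XLA typically compiles a graph per input shape, so a smaller trailing batch can trigger an extra compilation. The sketch below is plain Python under the assumption of the standard CIFAR-10 split (50,000 training and 10,000 test images). `batch_sizes` is a hypothetical helper that mirrors how a DataLoader splits a dataset.

```python
def batch_sizes(num_samples: int, batch_size: int, drop_last: bool = False) -> list:
    """List the size of every batch a DataLoader would yield in one epoch."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        sizes.append(remainder)  # the ragged final batch
    return sizes


train_sizes = batch_sizes(50_000, 1024)
test_sizes = batch_sizes(10_000, 1024)

print(len(train_sizes), train_sizes[-1])  # batches per epoch, last batch size
print(len(test_sizes), test_sizes[-1])
print(len(batch_sizes(50_000, 1024, drop_last=True)))
```

If the extra compilation turns out to matter in practice, passing `drop_last=True` to the training `DataLoader` is one way to keep every batch the same shape, at the cost of skipping the trailing samples each epoch.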


Lastly, we train for 20 epochs on the TPU by simply adding device='tpu' as an argument to the Trainer.

Note: we currently only support single-core TPUs in this beta release. Future releases will include multi-core TPU support.

[ ]:
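trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    optimizers=optimizer,
    max_duration="20ep",  # train for 20 epochs
    device="tpu",  # run on a single TPU core
)

trainer.fit()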


What next?

You've now seen a simple example of how to use the Composer trainer on a TPU. Cool!

To get to know Composer better, please continue to explore our tutorials!

Come get involved with MosaicML!

We'd love for you to get involved with the MosaicML community in any of these ways:

Star Composer on GitHub

Help make others aware of our work by starring Composer on GitHub.

Join the MosaicML Slack

Head on over to the MosaicML Slack to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!

Contribute to Composer

Is there a bug you noticed or a feature you'd like? File an issue or make a pull request!