Training with TPUs#
Composer provides beta support for single-core training on TPUs. We integrate with the torch_xla
backend. For installation instructions and more details, see here.
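If you'd like to sanity-check your installation before training, the snippet below (a minimal sketch, assuming torch_xla is installed and a TPU runtime is attached) asks torch_xla for the XLA device it will use:
import torch_xla.core.xla_model as xm

# Returns the default XLA device; on a TPU runtime this prints something like xla:0.
print(xm.xla_device())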
Recommended Background#
This tutorial is pretty straightforward. It uses the same basic training cycle set up in the Getting Started tutorial, which you might want to check out first if you haven't already.
Tutorial Goals and Concepts Covered#
The goal of this tutorial is to walk through the steps needed to train with Composer on TPUs. Concretely, we'll train a ResNet-20 on CIFAR10 using a single TPU core.
The training setup is exactly the same as on any other device, except that the model must be moved to the TPU device before being passed to the Trainer, and we must specify device='tpu' so the trainer uses the TPU. We'll walk through these steps below.
Letโs get started!
Prerequisites#
As prerequisites, first install torch_xla
and the latest Composer version.
[ ]:
%pip install mosaicml
# To install from source instead of the last release, comment the command above and uncomment the following one.
# %pip install git+https://github.com/mosaicml/composer.git
%pip install cloud-tpu-client==0.10 torch==1.12.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl
from composer import Trainer
from composer import models
Setup#
Model#
Next, we define the model and optimizer. TPUs require the model to be moved to the device before the optimizer is created, which we do here.
[ ]:
import torch
import torch_xla.core.xla_model as xm

# Build the model, then move it to the TPU (XLA) device *before* constructing the optimizer.
model = models.composer_resnet_cifar(model_name='resnet_20', num_classes=10)
model = model.to(xm.xla_device())

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.02,
    momentum=0.9)
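As an optional sanity check (not part of the original recipe), you can confirm that the parameters the optimizer will update already live on the XLA device:
# The optimizer holds references to these parameters, so they should report an XLA device.
print(next(model.parameters()).device)  # expected to print something like xla:0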
Datasets#
Creating the CIFAR10 dataset and dataloaders is exactly the same as on other, non-TPU devices.
[ ]:
from torchvision import datasets, transforms
data_directory = "../data"
# Normalization constants
mean = (0.507, 0.487, 0.441)
std = (0.267, 0.256, 0.276)
batch_size = 1024
cifar10_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
train_dataset = datasets.CIFAR10(data_directory, train=True, download=True, transform=cifar10_transforms)
test_dataset = datasets.CIFAR10(data_directory, train=False, download=True, transform=cifar10_transforms)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
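As a quick optional check that the data pipeline produces what we expect, you can pull a single batch and inspect its shape (a sketch; the sizes shown assume the batch_size of 1024 set above):
# Grab one batch from the training dataloader and look at its dimensions.
images, labels = next(iter(train_dataloader))
print(images.shape, labels.shape)  # expected: torch.Size([1024, 3, 32, 32]) and torch.Size([1024])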
Training#
Lastly, we train for 20 epochs on the TPU by simply adding device='tpu'
as an argument to the Trainer.
Note: we currently only support single-core TPUs in this beta release. Future releases will include multi-core TPU support.
[ ]:
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    device="tpu",
    eval_dataloader=test_dataloader,
    optimizers=optimizer,
    max_duration='20ep',
    eval_interval=1,
)
trainer.fit()
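If you want to keep the trained weights, one option (a sketch, not part of the original tutorial; the filename is just an example) is torch_xla's xm.save, which copies XLA tensors to the host before serializing them:
# xm.save moves tensors off the TPU to CPU before writing the checkpoint file.
xm.save(model.state_dict(), 'resnet20_cifar10.pt')  # hypothetical filename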
What next?#
You've now seen a simple example of how to use the Composer trainer on a TPU. Cool!
To get to know Composer more, please continue to explore our tutorials! Here are a couple of suggestions:
Continue learning about other Composer features like automatic gradient accumulation and automatic restarting from checkpoints.
Explore more advanced applications of Composer like applying image segmentation to medical images or fine-tuning a transformer for sentiment classification.
Keep it custom with our custom speedups tutorial.
Come get involved with MosaicML!#
We'd love for you to get involved with the MosaicML community in any of these ways:
Star Composer on GitHub#
Help make others aware of our work by starring Composer on GitHub.
Join the MosaicML Slack#
Head on over to the MosaicML Slack to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!
Contribute to Composer#
Is there a bug you noticed or a feature you'd like? File an issue or make a pull request!