Tip

This tutorial is available as a Jupyter notebook.

⚡ Migrating from PTL#

PyTorch Lightning is a popular and well-designed framework for training deep neural networks. You can use Composer's algorithms in your PyTorch Lightning code via the functional API with no additional code changes.
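As a rough illustration of the functional route, the sketch below applies BlurPool to a plain `torch.nn` module through `composer.functional`. The small `net` is a hypothetical stand-in for whatever network your LightningModule wraps, and the import is guarded only so the snippet also runs where Composer is not installed yet.

```python
import torch.nn as nn

try:
    import composer.functional as cf  # provided by `pip install mosaicml`
except ImportError:  # keeps this sketch runnable without Composer
    cf = None

# Hypothetical network standing in for the model inside your LightningModule.
net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
)

if cf is not None:
    # Module surgery happens in place, so apply it before creating the optimizer.
    cf.apply_blurpool(net, replace_convs=True, replace_maxpools=True, blur_first=True)
```

In a LightningModule, a call like this would typically sit in `__init__`, right after the model is created.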

However, if you are interested in features like automatic gradient accumulation, a clean time abstraction, and the easiest path to trying out different combinations of algorithms, you will need to switch from the PTL trainer to the Composer trainer.

Below is a quick guide on how to adapt your LightningModule to our simple interface.

Tutorial Goals and Concepts Covered#

The goal of this tutorial is to illustrate a path from working in PyTorch Lightning to working in Composer.

We'll primarily focus on the different ways models are structured in each framework, in order to illustrate how one maps on to the other.

Let's get started!

Setup#

We'll first install dependencies and define the data and model.

Install Dependencies#

If you haven't already, let's install Composer and PyTorch Lightning:

[ ]:
%pip install pytorch-lightning

%pip install mosaicml
# To install from source instead of the last release, comment the command above and uncomment the following one.
# %pip install git+https://github.com/mosaicml/composer.git

The Model#

In this section, we'll go through the process of migrating a ResNet-18 model from PTL to Composer. We will be following the PTL example here.

First, some relevant imports, as well as creating the model as in the PTL tutorial.

[ ]:
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models
from pytorch_lightning import LightningModule
from torch.optim.lr_scheduler import OneCycleLR

def create_model():
    model = torchvision.models.resnet18(pretrained=False, num_classes=10)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()
    return model
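To see why `create_model()` swaps the first convolution and removes the max pool, it can help to compare output shapes on a CIFAR-10-sized input. This is only a sketch of the two stems in isolation: it reproduces the layer hyperparameters of the stock torchvision ResNet-18 stem and leaves out the batch norm and ReLU, which do not change the spatial size.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 32, 32)  # one CIFAR-10-sized image

# Stem with the stock ResNet-18 hyperparameters (designed for ImageNet-sized inputs).
imagenet_stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)

# Stem as modified in create_model(): 3x3 stride-1 conv, max pool replaced by Identity.
cifar_stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
    nn.Identity(),
)

with torch.no_grad():
    imagenet_shape = tuple(imagenet_stem(x).shape)
    cifar_shape = tuple(cifar_stem(x).shape)

print(imagenet_shape)  # (1, 64, 8, 8)
print(cifar_shape)     # (1, 64, 32, 32)
```

The stock stem shrinks a 32x32 image to 8x8 before the first residual block, while the modified stem keeps the full resolution.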

Training data#

As is standard, we set up the training data for CIFAR-10 using torchvision datasets.

[ ]:
import torch
import torch.utils.data
import torchvision

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
)

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False)
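The transform above first scales pixels to `[0, 1]` (`ToTensor`) and then applies `(x - 0.5) / 0.5` per channel (`Normalize`), which lands values in `[-1, 1]`. The sketch below reproduces that arithmetic by hand on synthetic data, so nothing has to be downloaded; the random tensors and the dataset size of 600 are placeholders, not CIFAR-10.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Per-channel mean and std, shaped to broadcast over (C, H, W) images.
mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)

# Synthetic stand-ins for ToTensor() output: values in [0, 1).
images = torch.rand(600, 3, 32, 32)
labels = torch.randint(0, 10, (600,))

normalized = (images - mean) / std  # same formula Normalize applies

loader = DataLoader(TensorDataset(normalized, labels), batch_size=256, shuffle=False)
x, y = next(iter(loader))

print(x.shape, y.shape)  # torch.Size([256, 3, 32, 32]) torch.Size([256])
print(len(loader))       # 3 batches: 256 + 256 + 88 samples
```

Each batch is an `(inputs, targets)` pair, which is the structure the model code later unpacks.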

PTL Lightning Module#

Following the PTL tutorial, we use the LitResnet model:

[ ]:
from torchmetrics.functional import accuracy

class LitResnet(LightningModule):
    def __init__(self, lr=0.05):
        super().__init__()
        self.save_hyperparameters()
        self.model = create_model()

    def forward(self, x):
        out = self.model(x)
        return F.log_softmax(out, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
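        # Recent torchmetrics releases require the task to be stated explicitly.
        acc = accuracy(preds, y, task="multiclass", num_classes=10)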

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=5e-4,
        )
        steps_per_epoch = 45000 // 256
        scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer,
                0.1,
                epochs=30,
                steps_per_epoch=steps_per_epoch,
            ),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

PTLModel = LitResnet(lr=0.05)
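The module above returns log-probabilities from `forward` and pairs them with `F.nll_loss`. In PyTorch this two-step combination computes the same value as `F.cross_entropy` on the raw logits; the quick check below uses random tensors purely for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)           # raw model outputs for 4 samples, 10 classes
targets = torch.randint(0, 10, (4,))  # integer class labels

# Two-step form used by LitResnet: log_softmax in forward, nll_loss in the step.
two_step = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# One-step form on raw logits.
one_step = F.cross_entropy(logits, targets)

print(torch.allclose(two_step, one_step))  # True
```

This matters for the migration: whichever form you pick, the loss function has to match what `forward` returns.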

LitModel to Composer#

Notice that up to here, we have only used PyTorch Lightning code. Here we will modify the PTL module to be compatible with Composer. There are a few major differences:

  • The training_step is broken into two parts, the forward and the loss methods. This is needed since some algorithms (such as label smoothing or selective backprop) need to intercept and modify the loss.

  • Optimizers and schedulers are passed directly to the Trainer during initialization.

  • Our forward step accepts the entire batch as input and has to take care of unpacking the batch.
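To make the first and third points concrete, here is a minimal sketch of a training step written against a forward/loss split. `TinyModel` and `smooth_labels` are hypothetical names made up for this illustration; this is not Composer's trainer loop or its label smoothing implementation, only a toy showing where an algorithm could step in between the two calls.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyModel(nn.Module):
    """Hypothetical stand-in that mimics the forward/loss split."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 3)

    def forward(self, batch):
        x, _ = batch  # forward receives the whole batch and unpacks it
        return F.log_softmax(self.net(x), dim=1)

    def loss(self, outputs, batch):
        _, y = batch
        if y.dtype == torch.long:
            return F.nll_loss(outputs, y)        # hard class indices
        return -(y * outputs).sum(dim=1).mean()  # soft (smoothed) targets


def smooth_labels(batch, num_classes, alpha=0.1):
    """Toy label smoothing: rewrites the targets between forward and loss."""
    x, y = batch
    one_hot = F.one_hot(y, num_classes).float()
    return x, one_hot * (1 - alpha) + alpha / num_classes


torch.manual_seed(0)
model = TinyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = (torch.randn(8, 4), torch.randint(0, 3, (8,)))

outputs = model(batch)                       # 1. forward pass on the full batch
batch = smooth_labels(batch, num_classes=3)  # 2. an algorithm edits the targets
loss = model.loss(outputs, batch)            # 3. loss sees the modified batch
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

With a single `training_step`, step 2 would have nowhere to go without editing the model code itself.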

For more information about the ComposerModel format, see our documentation.

[ ]:
from torchmetrics.classification import MulticlassAccuracy
from composer.models.base import ComposerModel
PTLmodel = LitResnet(lr=0.05)

class MosaicResnet(ComposerModel):
    def __init__(self):
        super().__init__()
        self.model = create_model()
        self.acc = MulticlassAccuracy(num_classes=10, average='micro')

    def loss(self, outputs, batch, *args, **kwargs):
        """Accepts the outputs from forward() and the batch"""
        x, y = batch  # unpack the labels
        return F.nll_loss(outputs, y)

    def get_metrics(self, is_train):
        return {'MulticlassAccuracy': self.acc}

    def forward(self, batch):
        x, _ = batch
        y = self.model(x)
        return F.log_softmax(y, dim=1)

    def eval_forward(self, batch, outputs = None):
        return outputs if outputs is not None else self.forward(batch)

    def update_metric(self, batch, outputs, metric) -> None:
        _, targets = batch
        metric.update(outputs, targets)
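The evaluation-side methods fit together roughly as follows. This is a sketch under stated assumptions, not the trainer's real loop: `TinyEvalModel` is a hypothetical model with the same method names as above, and `CountingAccuracy` is a hand-rolled stand-in for a torchmetrics metric that only implements `update` and `compute`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CountingAccuracy:
    """Hypothetical stand-in for a torchmetrics metric (update/compute only)."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, outputs, targets):
        preds = outputs.argmax(dim=1)
        self.correct += int((preds == targets).sum())
        self.total += targets.numel()

    def compute(self):
        return self.correct / self.total


class TinyEvalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 3)
        self.acc = CountingAccuracy()

    def forward(self, batch):
        x, _ = batch
        return F.log_softmax(self.net(x), dim=1)

    def eval_forward(self, batch, outputs=None):
        # Reuse outputs if they were already computed, otherwise run forward.
        return outputs if outputs is not None else self.forward(batch)

    def get_metrics(self, is_train=False):
        return {"accuracy": self.acc}

    def update_metric(self, batch, outputs, metric):
        _, targets = batch
        metric.update(outputs, targets)


torch.manual_seed(0)
model = TinyEvalModel().eval()
batches = [(torch.randn(8, 4), torch.randint(0, 3, (8,))) for _ in range(3)]

with torch.no_grad():
    for batch in batches:
        outputs = model.eval_forward(batch)
        for metric in model.get_metrics(is_train=False).values():
            model.update_metric(batch, outputs, metric)

final_accuracy = model.get_metrics()["accuracy"].compute()
print(f"accuracy over {model.acc.total} samples: {final_accuracy:.3f}")
```

The metric object accumulates state across batches, and a single `compute()` call at the end yields the value for the whole evaluation set.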

Training#

We instantiate the Composer trainer similarly by specifying the model, dataloaders, optimizers, and max_duration (epochs). For more details on the trainer arguments, see our Using the Trainer guide.

Now you are ready to insert your algorithms! As an example, here we add the BlurPool algorithm.

[ ]:
from composer import Trainer
from composer.algorithms import BlurPool

model = MosaicResnet()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.05,
    momentum=0.9,
    weight_decay=5e-4,
)

steps_per_epoch = 45000 // 256

scheduler = OneCycleLR(
    optimizer,
    0.1,
    epochs=30,
    steps_per_epoch=steps_per_epoch,
)

trainer = Trainer(
    model=model,
    algorithms=[
        BlurPool(
            replace_convs=True,
            replace_maxpools=True,
            blur_first=True
        ),
    ],
    train_dataloader=train_dataloader,
    device="gpu" if torch.cuda.is_available() else "cpu",
    eval_dataloader=test_dataloader,
    optimizers=optimizer,
    schedulers=scheduler,
    step_schedulers_every_batch=True,  # step the scheduler every batch, matching "interval": "step" in the PTL config
    max_duration='2ep',
    eval_interval=1,
    train_subset_num_batches=1,
)
trainer.fit()
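It is worth checking the scheduler arithmetic when adapting this to a real run. The short calculation below uses only numbers from the cells above, plus the fact that the CIFAR-10 training split has 50,000 images; the `45000` constant is the one copied from the LightningModule, and the variable names are just for illustration.

```python
import math

batch_size = 256
epochs = 30

# Horizon the OneCycleLR scheduler was configured with.
steps_per_epoch = 45000 // batch_size             # 175
scheduler_total_steps = epochs * steps_per_epoch  # 5250

# Batches per epoch if the full 50,000-image training set is used (drop_last=False).
full_batches_per_epoch = math.ceil(50000 / batch_size)  # 196
full_run_steps = epochs * full_batches_per_epoch        # 5880

# Steps actually taken by the Trainer call above:
# max_duration='2ep' with train_subset_num_batches=1 means one batch per epoch.
demo_steps = 2 * 1

print(scheduler_total_steps, full_run_steps, demo_steps)  # 5250 5880 2
```

The demo configuration only takes two optimizer steps, so it stays far inside the scheduler's horizon. For a full 30-epoch run on this dataloader, `steps_per_epoch` would likely need to be `len(train_dataloader)` instead, since a plain PyTorch `OneCycleLR` raises an error when stepped past its total step count.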

What next?#

Hopefully this tutorial provides you with some useful intuitions for making the jump from PyTorch Lightning to Composer.

To continue learning about Composer, check out our guide to using the trainer and explore more of our tutorials!

Come get involved with MosaicML!#

We'd love for you to get involved with the MosaicML community in any of these ways:

Star Composer on GitHub#

Help make others aware of our work by starring Composer on GitHub.

Join the MosaicML Slack#

Head on over to the MosaicML Slack to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!

Contribute to Composer#

Is there a bug you noticed or a feature you'd like? File an issue or make a pull request!