🛻 ComposerModel

Your PyTorch model and training step must be reorganized as a ComposerModel in order to use our Trainer. This interface gives the trainer access to the parts of your model it needs, such as the forward pass and the loss, so that it can speed up training.

Using your own Model

To create your own model, subclass ComposerModel and define the forward() and loss() methods. Here is an example with a trainable torchvision ResNet-18 classifier and a cross-entropy loss.

Notice that the forward pass is still under user control (no magic here!) and is encapsulated clearly within the model class.

The trainer takes care of:

  • x.to(device), y.to(device)

  • loss.backward()

  • optimizer.zero_grad()

  • optimizer.step()

It also handles other features such as distributed training, numerical precision (for example, mixed precision), and gradient accumulation.
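To make this division of labor concrete, here is a minimal pure-Python sketch of one training step. It is not Composer's Trainer code: `toy_train_step`, `RecordingLoss`, and `RecordingOptimizer` are hypothetical stand-ins that only record the order of calls, so the example runs without PyTorch. The user owns `forward()` and `loss()`; everything else belongs to the trainer.

```python
class RecordingLoss:
    """Hypothetical stand-in for a loss tensor; records when backward() runs."""

    def __init__(self, value, log):
        self.value = value
        self.log = log

    def backward(self):
        self.log.append("loss.backward")


class RecordingOptimizer:
    """Hypothetical stand-in for a torch optimizer."""

    def __init__(self, log):
        self.log = log

    def zero_grad(self):
        self.log.append("optimizer.zero_grad")

    def step(self):
        self.log.append("optimizer.step")


class ToyModel:
    """User-owned part: only forward() and loss() are defined here."""

    def __init__(self, log):
        self.log = log

    def forward(self, batch):
        self.log.append("model.forward")
        inputs, _ = batch
        return [2 * x for x in inputs]

    def loss(self, outputs, batch):
        self.log.append("model.loss")
        _, targets = batch
        value = sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(targets)
        return RecordingLoss(value, self.log)


def toy_train_step(model, optimizer, batch):
    """Trainer-owned part (sketch): device moves, backward, and optimizer calls."""
    # A real trainer would also move the batch to the target device here.
    optimizer.zero_grad()
    outputs = model.forward(batch)
    loss = model.loss(outputs, batch)
    loss.backward()
    optimizer.step()
    return loss.value


log = []
loss_value = toy_train_step(ToyModel(log), RecordingOptimizer(log), ([1.0, 2.0], [2.0, 5.0]))
print(log)
print(loss_value)
```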

import torchvision
import torch.nn.functional as F

from composer.models import ComposerModel

class ResNet18(ComposerModel):

    def __init__(self):
        super().__init__()  # initialize the underlying torch.nn.Module before assigning submodules
        self.model = torchvision.models.resnet18()

    def forward(self, batch): # batch is the output of the dataloader
        # specify how batches are passed through the model
        inputs, _ = batch
        return self.model(inputs)

    def loss(self, outputs, batch):
        # pass batches and `forward` outputs to the loss
        _, targets = batch
        return F.cross_entropy(outputs, targets)

The Composer model can then be passed to our trainer.

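import torch.optim as optim
from composer import Trainer

model = ResNet18()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
train_dataloader = ...  # placeholder: your standard PyTorch DataLoader

trainer = Trainer(
    model=model,
    optimizers=optimizer,
    train_dataloader=train_dataloader,
    max_duration='10ep',  # example training length; choose one for your run
)
trainer.fit()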

Both the forward() and loss() methods are passed the batch directly from the dataloader. We leave the unpacking of that batch into inputs and targets to the user since it can vary depending on the task.
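Because each model unpacks its own batch, the same trainer-side code works whether a dataloader yields `(inputs, targets)` tuples or dictionaries. The pure-Python sketch below uses hypothetical toy models (`TupleBatchModel`, `DictBatchModel`) and dictionary keys chosen only for illustration; it mirrors the `forward(batch)` / `loss(outputs, batch)` calling convention described above.

```python
class TupleBatchModel:
    """Hypothetical model whose dataloader yields (inputs, targets) tuples."""

    def forward(self, batch):
        inputs, _ = batch
        return [x + 1 for x in inputs]

    def loss(self, outputs, batch):
        _, targets = batch
        return sum(abs(o - t) for o, t in zip(outputs, targets))


class DictBatchModel:
    """Hypothetical model whose dataloader yields dicts (keys are illustrative)."""

    def forward(self, batch):
        return [x + 1 for x in batch["input_ids"]]

    def loss(self, outputs, batch):
        return sum(abs(o - t) for o, t in zip(outputs, batch["labels"]))


def compute_loss(model, batch):
    # Trainer-side code never looks inside the batch; it just passes it through.
    outputs = model.forward(batch)
    return model.loss(outputs, batch)


tuple_loss = compute_loss(TupleBatchModel(), ([1, 2], [2, 4]))
dict_loss = compute_loss(DictBatchModel(), {"input_ids": [1, 2], "labels": [2, 4]})
print(tuple_loss, dict_loss)
```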

We also provide several common classes for various tasks, specifically:


Users coming from other frameworks such as PyTorch Lightning may be used to defining a training_step method, which groups the forward pass and the loss together. However, many of our algorithmic methods (such as label smoothing or selective backprop) need to intercept and modify the loss. For this reason, we split it into two separate methods.
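The pure-Python sketch below shows why a separate loss() helps: an algorithm can change how the loss is computed without touching forward(). `ToyClassifier` and `apply_label_smoothing` are hypothetical, the smoothing factor is arbitrary, and this is only an illustration, not Composer's LabelSmoothing implementation.

```python
import math


class ToyClassifier:
    """Hypothetical model whose outputs are already probability vectors."""

    def forward(self, batch):
        probs, _ = batch  # pretend the network produced these probabilities
        return probs

    def loss(self, outputs, batch):
        _, targets = batch  # one-hot (or soft) target vectors
        # mean cross-entropy between target distributions and predicted probabilities
        return sum(
            -sum(t * math.log(p) for t, p in zip(target, probs))
            for target, probs in zip(targets, outputs)
        ) / len(targets)


def apply_label_smoothing(model, alpha):
    """Hypothetical algorithm: smooth the targets, then call the model's own loss."""
    original_loss = model.loss

    def smoothed_loss(outputs, batch):
        inputs, targets = batch
        num_classes = len(targets[0])
        smoothed = [[(1 - alpha) * t + alpha / num_classes for t in target] for target in targets]
        return original_loss(outputs, (inputs, smoothed))

    model.loss = smoothed_loss  # forward() is left untouched
    return model


batch = ([[0.9, 0.1]], [[1.0, 0.0]])
model = ToyClassifier()
plain = model.loss(model.forward(batch), batch)
apply_label_smoothing(model, alpha=0.1)
smoothed = model.loss(model.forward(batch), batch)
print(plain, smoothed)
```

Smoothing moves probability mass away from the correct class in the target, so a confident correct prediction is penalized slightly more.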

By convention, we define our PyTorch layers in the self.model attribute of ComposerModel. We encourage this pattern because it makes it easier to extract the underlying model for inference when training is completed. However, this is not enforced, and users can configure the layers directly in the class if they prefer.
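As a rough illustration of why the self.model convention is convenient, the pure-Python sketch below keeps the network in `self.model` and then pulls it out for inference, calling it on raw inputs instead of a training batch. `TinyNetwork` and `WrappedModel` are hypothetical stand-ins for a PyTorch module and a ComposerModel subclass.

```python
class TinyNetwork:
    """Hypothetical stand-in for a plain PyTorch module (e.g. a ResNet)."""

    def __call__(self, inputs):
        return [3 * x for x in inputs]


class WrappedModel:
    """Hypothetical stand-in for a ComposerModel that follows the self.model convention."""

    def __init__(self):
        self.model = TinyNetwork()

    def forward(self, batch):
        # training-time interface: takes a whole batch and unpacks the inputs
        inputs, _ = batch
        return self.model(inputs)


wrapped = WrappedModel()
train_outputs = wrapped.forward(([1, 2], [0, 0]))

# After training, the bare network can be used for inference on raw inputs.
inference_model = wrapped.model
predictions = inference_model([1, 2])
print(train_outputs, predictions)
```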


To compute metrics during training, implement the following methods:

def eval_forward(self, batch, outputs=None) -> Any: ...

def get_metrics(self, is_train=False) -> dict[str, Metric]: ...

def update_metric(self, batch, outputs, metric) -> None: ...

where each Metric should be compatible with the torchmetrics.Metric protocol. We require that the output of ComposerModel.eval_forward() be consumable by that protocol. Specifically, the validation loop does something like this:

metrics = model.get_metrics(is_train=False)

for batch in val_dataloader:
    outputs = model.eval_forward(batch)
    for m in metrics.values():
        model.update_metric(batch, outputs, m)

for metric in metrics.values():
    metric.compute()
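The validation loop above can be exercised end to end with the pure-Python sketch below. `ToyAccuracy` is a hypothetical metric that mimics only the update()/compute() part of the torchmetrics.Metric interface, and `ToyComposerModel` implements the three metric-related methods; neither is Composer or torchmetrics code.

```python
class ToyAccuracy:
    """Hypothetical metric mimicking the torchmetrics update()/compute() pattern."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, targets):
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self):
        return self.correct / self.total


class ToyComposerModel:
    """Hypothetical model implementing eval_forward, get_metrics, and update_metric."""

    def __init__(self):
        self.train_accuracy = ToyAccuracy()
        self.val_accuracy = ToyAccuracy()

    def eval_forward(self, batch, outputs=None):
        if outputs is not None:
            return outputs  # reuse outputs that were already computed
        inputs, _ = batch
        return [1 if x > 0 else 0 for x in inputs]  # toy "classifier"

    def get_metrics(self, is_train=False):
        return {"Accuracy": self.train_accuracy if is_train else self.val_accuracy}

    def update_metric(self, batch, outputs, metric):
        _, targets = batch
        metric.update(outputs, targets)


model = ToyComposerModel()
val_dataloader = [([3, -1], [1, 0]), ([2, -5], [0, 0])]

metrics = model.get_metrics(is_train=False)
for batch in val_dataloader:
    outputs = model.eval_forward(batch)
    for m in metrics.values():
        model.update_metric(batch, outputs, m)

results = {name: metric.compute() for name, metric in metrics.items()}
print(results)
```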

A full example of a validation implementation would be:

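import torch.nn.functional as F
import torchmetrics
import torchvision

from composer.models import ComposerModel

class ComposerClassifier(ComposerModel):

    def __init__(self):
        super().__init__()  # initialize the underlying torch.nn.Module
        self.model = torchvision.models.resnet18()
        self.train_accuracy = torchmetrics.classification.MulticlassAccuracy(num_classes=1000, average='micro')
        self.val_accuracy = torchmetrics.classification.MulticlassAccuracy(num_classes=1000, average='micro')

    def forward(self, batch):
        inputs, _ = batch
        return self.model(inputs)

    def loss(self, outputs, batch):
        _, targets = batch
        return F.cross_entropy(outputs, targets)

    def eval_forward(self, batch, outputs=None):
        # reuse outputs if they were already computed; otherwise run the model
        if outputs is not None:
            return outputs
        inputs, _ = batch
        return self.model(inputs)

    def update_metric(self, batch, outputs, metric):
        _, targets = batch
        metric.update(outputs, targets)

    def get_metrics(self, is_train=False):
        # defines which metrics to use in each phase of training
        return {'MulticlassAccuracy': self.train_accuracy} if is_train else {'MulticlassAccuracy': self.val_accuracy}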


There is no need to call model.eval() or use torch.no_grad(); our trainer takes care of that. torchmetrics also handles synchronizing metric state across processes when using distributed training.

Logging Results

The trainer automatically logs the results of the metrics and the loss using all of the loggers specified by the user. For example, to log the results to a dict, use the InMemoryLogger.

See also

Our guide to Logging.
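Conceptually, logging to a dict means keeping a history of values for each metric name. The sketch below is a hypothetical pure-Python `DictLogger` that records `(step, value)` pairs under illustrative key names; it is not Composer's InMemoryLogger, whose actual attributes and methods are documented in the Composer API reference.

```python
from collections import defaultdict


class DictLogger:
    """Hypothetical logger that stores every logged value in a dict of lists."""

    def __init__(self):
        self.data = defaultdict(list)

    def log_metrics(self, metrics, step):
        # append (step, value) so the full history of each metric is kept
        for name, value in metrics.items():
            self.data[name].append((step, value))

    def most_recent(self, name):
        return self.data[name][-1][1]


logger = DictLogger()
logger.log_metrics({"train_loss": 0.9}, step=1)
logger.log_metrics({"train_loss": 0.7, "val_accuracy": 0.5}, step=2)
print(dict(logger.data))
```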

Multiple Metrics

To run multiple metrics, wrap them in a torchmetrics.MetricCollection.

from torchmetrics.collections import MetricCollection

# inside your ComposerModel's __init__, assuming the individual metrics are already defined
self.train_metrics = MetricCollection([self.train_loss, self.train_accuracy])
self.eval_metrics = MetricCollection([self.val_loss, self.val_accuracy])
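The idea behind wrapping metrics in a collection is that a single update() call fans out to every metric, and compute() returns one result per metric. The pure-Python sketch below mimics that behavior with a hypothetical `ToyMetricCollection` and two hypothetical metrics; keying metrics by class name is an assumption meant to resemble how a list of metrics is typically named, not torchmetrics' actual implementation.

```python
class ToyMAE:
    """Hypothetical metric: mean absolute difference between preds and targets."""

    def __init__(self):
        self.abs_error = 0.0
        self.count = 0

    def update(self, preds, targets):
        self.abs_error += sum(abs(p - t) for p, t in zip(preds, targets))
        self.count += len(targets)

    def compute(self):
        return self.abs_error / self.count


class ToyExactMatch:
    """Hypothetical metric: fraction of predictions equal to their targets."""

    def __init__(self):
        self.matches = 0
        self.count = 0

    def update(self, preds, targets):
        self.matches += sum(int(p == t) for p, t in zip(preds, targets))
        self.count += len(targets)

    def compute(self):
        return self.matches / self.count


class ToyMetricCollection:
    """Hypothetical collection: one update() call fans out to every metric."""

    def __init__(self, metrics):
        # key each metric by its class name (an illustrative choice)
        self.metrics = {type(m).__name__: m for m in metrics}

    def update(self, preds, targets):
        for metric in self.metrics.values():
            metric.update(preds, targets)

    def compute(self):
        return {name: metric.compute() for name, metric in self.metrics.items()}


collection = ToyMetricCollection([ToyMAE(), ToyExactMatch()])
collection.update([1.0, 2.0, 4.0], [1.0, 3.0, 4.0])
results = collection.compute()
print(results)
```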


All of the provided metrics are computed on the validation dataset. If you have multiple evaluation datasets with different metrics, we recommend using an Evaluator (see Evaluation).


BERT Example with 🤗 Transformers

In this example, we create a BERT model loaded from 🤗 Transformers and make it compatible with our trainer.

from transformers import AutoModelForSequenceClassification
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.collections import MetricCollection

from composer.models import HuggingFaceModel
from composer.metrics import LanguageCrossEntropy

# huggingface model
model = AutoModelForSequenceClassification.from_pretrained(

# list of torchmetrics
metrics = [LanguageCrossEntropy(), MulticlassAccuracy(num_classes=2, average='micro')]

# composer model, ready to be passed to our trainer
composer_model = HuggingFaceModel(model, metrics=metrics)

YOLOX Example with MMDetection

In this example, we create a YOLOX model loaded from MMDetection and make it compatible with our trainer.

from mmdet.models import build_detector
from mmcv import ConfigDict
from composer.models import MMDetModel

# yolox config from https://github.com/open-mmlab/mmdetection/blob/master/configs/yolox/yolox_s_8x8_300e_coco.py
yolox_s_config = dict(
    input_size=(640, 640),
    random_size_range=(15, 25),
    backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
    neck=dict(type='YOLOXPAFPN', in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1),
    bbox_head=dict(type='YOLOXHead', num_classes=80, in_channels=128, feat_channels=128),
    train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
    test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))

yolox = build_detector(ConfigDict(yolox_s_config))
model = MMDetModel(yolox)