Your PyTorch model and training step must be reorganized as a
ComposerModel in order to use our trainer.
This interface helps our trainer access the necessary parts of your model
to easily speed up training.
Using your own Model
For example, a torchvision ResNet-18 can be wrapped as follows:

```python
import torchvision
import torch.nn.functional as F

from composer.models import ComposerModel


class ResNet18(ComposerModel):

    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet18()

    def forward(self, batch):
        # batch is the output of the dataloader
        # specify how batches are passed through the model
        inputs, _ = batch
        return self.model(inputs)

    def loss(self, outputs, batch):
        # pass batches and `forward` outputs to the loss
        _, targets = batch
        return F.cross_entropy(outputs, targets)
```

Notice how the forward pass is still under user control (no magic here!) and clearly encapsulated within the architecture.
The trainer takes care of the rest of the training loop, such as moving each batch to the device and calling loss.backward(), optimizer.zero_grad(), and optimizer.step(), as well as other features such as distributed training, numerics, and gradient accumulation.
The Composer model can then be passed to our trainer.
```python
import torch.optim as optim

from composer import Trainer

model = ResNet18()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
train_dataloader = ...  # a standard PyTorch DataLoader yielding (inputs, targets) batches

trainer = Trainer(
    model=model,
    optimizers=optimizer,
    train_dataloader=train_dataloader,
    max_duration='10ep',  # train for 10 epochs
)
trainer.fit()
```
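To make the division of labor concrete, here is a simplified, hand-written loop showing roughly what a trainer does with the forward() and loss() methods on each batch. This is a sketch in plain PyTorch, not Composer's actual implementation: TinyClassifier and train_one_epoch are hypothetical names, and the real trainer layers distributed training, numerics, gradient accumulation, and algorithm hooks on top of these steps.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class TinyClassifier(nn.Module):
    """Hypothetical stand-in with the same forward/loss split as ResNet18 above."""

    def __init__(self, num_features: int = 4, num_classes: int = 3):
        super().__init__()
        self.model = nn.Linear(num_features, num_classes)

    def forward(self, batch):
        inputs, _ = batch
        return self.model(inputs)

    def loss(self, outputs, batch):
        _, targets = batch
        return F.cross_entropy(outputs, targets)


def train_one_epoch(model, optimizer, dataloader, device="cpu"):
    """Simplified loop: roughly the per-batch steps a trainer performs for you."""
    model.train()
    losses = []
    for batch in dataloader:
        batch = [t.to(device) for t in batch]  # move data to the device
        optimizer.zero_grad()                  # clear gradients from the last step
        outputs = model(batch)                 # user-defined forward()
        loss = model.loss(outputs, batch)      # user-defined loss()
        loss.backward()                        # backpropagate
        optimizer.step()                       # update parameters
        losses.append(loss.item())
    return losses


torch.manual_seed(0)
dataset = TensorDataset(torch.randn(32, 4), torch.randint(0, 3, (32,)))
loader = DataLoader(dataset, batch_size=8)
model = TinyClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
losses = train_one_epoch(model, optimizer, loader)
```

Because the user only supplies forward() and loss(), everything else in the loop can be managed (and modified) by the trainer.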
Both the forward() and loss() methods are passed the batch directly
from the dataloader. We leave the unpacking of that batch into inputs and targets
to the user, since it can vary depending on the task.
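For instance, if a dataloader yields dictionaries rather than (inputs, targets) tuples, forward() and loss() simply unpack by key instead. The sketch below uses a plain nn.Module so it runs without Composer; the class name DictBatchModel and the keys "features" and "label" are hypothetical and stand in for whatever your dataloader produces.

```python
import torch
import torch.nn.functional as F
from torch import nn


class DictBatchModel(nn.Module):
    """Hypothetical model whose dataloader yields dicts instead of tuples."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 2)

    def forward(self, batch):
        # Unpack by key; the key names depend entirely on your dataloader.
        return self.model(batch["features"])

    def loss(self, outputs, batch):
        return F.cross_entropy(outputs, batch["label"])


model = DictBatchModel()
batch = {"features": torch.randn(5, 4), "label": torch.tensor([0, 1, 1, 0, 1])}
outputs = model(batch)
loss = model.loss(outputs, batch)
```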
We also provide several common classes for various tasks, specifically:
- ComposerClassifier: classification tasks with a cross-entropy loss and an accuracy metric.
- TIMM: creates classification models from the popular TIMM library.
Users coming from other frameworks such as PyTorch Lightning may be used to defining a
training_step method that groups the forward pass and the loss
together. However, many of our algorithmic methods (such as
label smoothing or selective backprop) need to intercept and modify the
loss. For this reason, we split it into two separate methods.
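As an illustration of why the split matters, the sketch below swaps a model's loss for a label-smoothed cross entropy while leaving forward() untouched. It is plain PyTorch with hypothetical names (TinyClassifier, label_smoothing_loss, apply_label_smoothing) and is not Composer's LabelSmoothing implementation; it only shows that an algorithm can replace loss() independently when it is a separate method.

```python
import functools

import torch
import torch.nn.functional as F
from torch import nn


class TinyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 3)

    def forward(self, batch):
        inputs, _ = batch
        return self.model(inputs)

    def loss(self, outputs, batch):
        _, targets = batch
        return F.cross_entropy(outputs, targets)


def label_smoothing_loss(outputs, batch, alpha=0.1):
    """Cross entropy against targets mixed with the uniform distribution."""
    _, targets = batch
    num_classes = outputs.shape[-1]
    one_hot = F.one_hot(targets, num_classes).float()
    soft_targets = one_hot * (1 - alpha) + alpha / num_classes
    log_probs = F.log_softmax(outputs, dim=-1)
    return -(soft_targets * log_probs).sum(dim=-1).mean()


def apply_label_smoothing(model, alpha=0.1):
    """Hypothetical 'algorithm': replace loss() on this instance, keep forward()."""
    model.loss = functools.partial(label_smoothing_loss, alpha=alpha)
    return model


torch.manual_seed(0)
batch = (torch.randn(6, 4), torch.tensor([0, 1, 2, 0, 1, 2]))
model = TinyClassifier()
outputs = model(batch)
original = model.loss(outputs, batch)

apply_label_smoothing(model, alpha=0.1)
smoothed = model.loss(outputs, batch)
```

With alpha=0 the smoothed loss reduces to ordinary cross entropy, which is a quick sanity check on the formula.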
By convention, we define our PyTorch layers in the self.model attribute of the
ComposerModel. We encourage this pattern because
it makes it easier to extract the underlying model for inference when training is
complete. However, this is not enforced, and users can configure the
layers directly in the class if they prefer.
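A minimal sketch of that extraction, assuming the layers live in self.model as in the ResNet18 example. Wrapper is a hypothetical plain nn.Module standing in for a ComposerModel so the snippet runs without Composer; after training, the inner module is all you need for inference.

```python
import torch
import torch.nn.functional as F
from torch import nn


class Wrapper(nn.Module):
    """Hypothetical stand-in for a ComposerModel that keeps its layers in self.model."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

    def forward(self, batch):
        inputs, _ = batch
        return self.model(inputs)

    def loss(self, outputs, batch):
        _, targets = batch
        return F.cross_entropy(outputs, targets)


wrapper = Wrapper()
# ... training would happen here ...

# For inference, take the plain nn.Module: it needs no targets and no loss().
inference_model = wrapper.model.eval()
with torch.no_grad():
    predictions = inference_model(torch.randn(3, 4))

# Its state_dict keys drop the "model." prefix that the wrapper adds.
state_dict = inference_model.state_dict()
```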
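To compute metrics during training, implement the following methods:

```python
def validate(self, batch):
    """Return a tuple of (outputs, targets) for the given batch."""
    ...

def metrics(self, train: bool = False):
    """Return the torchmetrics Metric (or MetricCollection) to use in this phase."""
    ...
```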
Metrics should be compatible with the
torchmetrics package. We
require that the output of
ComposerModel.validate() be consumable by
torchmetrics. Specifically, the validation loop does something like this:
```python
metrics = model.metrics(train=False)

for batch in val_dataloader:
    outputs, targets = model.validate(batch)
    metrics.update(outputs, targets)  # implements the torchmetrics interface

metrics.compute()
```
A full example of a validation implementation would be:
```python
import torchmetrics
import torchvision

from composer.models import ComposerModel


class ComposerClassifier(ComposerModel):

    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet18()
        # newer torchmetrics releases require extra arguments, e.g. task='multiclass'
        self.train_accuracy = torchmetrics.Accuracy()
        self.val_accuracy = torchmetrics.Accuracy()

    ...  # forward() and loss() omitted for brevity

    def validate(self, batch):
        inputs, targets = batch
        outputs = self.model(inputs)
        return outputs, targets

    def metrics(self, train=False):
        # defines which metrics to use in each phase of training
        return self.train_accuracy if train else self.val_accuracy
```
There is no need to wrap validation in
torch.no_grad(); the trainer takes care
of that.
torchmetrics also handles synchronizing metric state across processes
when using distributed training.
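For intuition, here is a runnable approximation of that validation loop, including the torch.no_grad() context the trainer manages for you. SimpleAccuracy is a hypothetical stand-in that mimics the torchmetrics reset()/update()/compute() interface so the sketch runs without extra dependencies; in practice metrics() would return real torchmetrics objects, and evaluate is an illustrative helper, not a Composer function.

```python
import torch
from torch import nn


class SimpleAccuracy:
    """Hypothetical metric mimicking the torchmetrics update()/compute() interface."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.correct = 0
        self.total = 0

    def update(self, outputs, targets):
        preds = outputs.argmax(dim=-1)
        self.correct += (preds == targets).sum().item()
        self.total += targets.numel()

    def compute(self):
        return self.correct / self.total


class TinyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(2, 2)
        self.train_accuracy = SimpleAccuracy()
        self.val_accuracy = SimpleAccuracy()

    def validate(self, batch):
        inputs, targets = batch
        return self.model(inputs), targets

    def metrics(self, train=False):
        return self.train_accuracy if train else self.val_accuracy


def evaluate(model, val_batches):
    """Roughly what an eval loop does: no_grad, update per batch, compute once."""
    model.eval()
    metric = model.metrics(train=False)
    metric.reset()
    with torch.no_grad():  # gradients are not needed during evaluation
        for batch in val_batches:
            outputs, targets = model.validate(batch)
            metric.update(outputs, targets)
    return metric.compute()


model = TinyClassifier()
with torch.no_grad():  # make the layer an identity map so results are predictable
    model.model.weight.copy_(torch.eye(2))
    model.model.bias.zero_()

val_batches = [
    (torch.tensor([[2.0, 0.0], [0.0, 3.0]]), torch.tensor([0, 1])),  # both correct
    (torch.tensor([[1.0, 0.0], [5.0, 1.0]]), torch.tensor([1, 0])),  # one correct
]
accuracy = evaluate(model, val_batches)  # 3 of 4 predictions correct
```

Accumulating counts across batches and computing once at the end is what lets the metric give an exact dataset-level accuracy rather than an average of per-batch values.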
The trainer automatically logs the results of the metrics and the loss
using all of the
loggers specified by the user. For example, you can log
the results to a
dict; see our guide to Logging for details.
To run multiple metrics, wrap them in a MetricCollection:

```python
from torchmetrics.collections import MetricCollection

def metrics(self, train: bool = False):
    # self.train_loss, self.train_accuracy, etc. are created in __init__
    if train:
        return MetricCollection([self.train_loss, self.train_accuracy])
    return MetricCollection([self.val_loss, self.val_accuracy])
```
Integrate your favorite TIMM models with our Timm class:

```python
from composer.models import Timm

timm_model = Timm(model_name='resnet50', pretrained=True)
```
BERT Example with 🤗 Transformers
In this example, we create a BERT model loaded from 🤗 Transformers and make it compatible with our trainer.
```python
from transformers import AutoModelForSequenceClassification
from torchmetrics import Accuracy
from torchmetrics.collections import MetricCollection

from composer import ComposerModel
from composer.models.nlp_metrics import LanguageCrossEntropyLoss


class ComposerBERT(ComposerModel):

    def __init__(self, num_labels):
        super().__init__()
        # huggingface model
        self.model = AutoModelForSequenceClassification.from_pretrained(
            'bert-base-uncased',
            num_labels=num_labels,
        )

        # Metrics
        self.train_loss = LanguageCrossEntropyLoss()
        self.val_loss = LanguageCrossEntropyLoss()

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

    def forward(self, batch):
        outputs = self.model(**batch)
        return outputs

    def loss(self, outputs, batch):
        return outputs['loss']  # huggingface models output a dictionary

    def validate(self, batch):
        labels = batch.pop('labels')
        output = self.forward(batch)
        output = output['logits']
        return output, labels

    def metrics(self, train: bool = False):
        if train:
            return MetricCollection([self.train_loss, self.train_acc])
        return MetricCollection([self.val_loss, self.val_acc])
```
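The BERT example relies on two Hugging Face conventions: the model computes its own loss when the batch contains a 'labels' key, and it exposes predictions under 'logits'. The toy module below is a hypothetical stand-in (ToySequenceClassifier and ToyComposerWrapper are illustrative names) that mimics just that interface, so the control flow can be exercised without downloading BERT. One small design change is labeled here: validate() copies the batch before popping 'labels', so the caller's dict is not mutated.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToySequenceClassifier(nn.Module):
    """Hypothetical stand-in mimicking the Hugging Face output convention."""

    def __init__(self, vocab_size=10, num_labels=2):
        super().__init__()
        self.embed = nn.EmbeddingBag(vocab_size, 8)  # mean-pools token embeddings
        self.head = nn.Linear(8, num_labels)

    def forward(self, input_ids, labels=None):
        logits = self.head(self.embed(input_ids))
        outputs = {"logits": logits}
        if labels is not None:  # loss is only computed when labels are supplied
            outputs["loss"] = F.cross_entropy(logits, labels)
        return outputs


class ToyComposerWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToySequenceClassifier()

    def forward(self, batch):
        return self.model(**batch)

    def loss(self, outputs, batch):
        return outputs["loss"]

    def validate(self, batch):
        batch = dict(batch)  # shallow copy so the caller's batch keeps its labels
        labels = batch.pop("labels")
        return self.forward(batch)["logits"], labels


torch.manual_seed(0)
batch = {"input_ids": torch.randint(0, 10, (4, 5)), "labels": torch.tensor([0, 1, 1, 0])}
wrapper = ToyComposerWrapper()
train_outputs = wrapper(batch)
logits, labels = wrapper.validate(batch)
```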