composer.models.base
The ComposerModel base interface.
Classes
ComposerModel: The interface needed to make a PyTorch model compatible with composer.Trainer.
- class composer.models.base.ComposerModel
Bases: torch.nn.modules.module.Module, abc.ABC
The interface needed to make a PyTorch model compatible with composer.Trainer.
To create a Trainer-compatible model, subclass ComposerModel and implement forward() and loss(). For full functionality (logging and validation), also implement metrics() and validate().
See the Composer Model walkthrough for more details.
Minimal Example:
```python
import torchvision
import torch.nn.functional as F

from composer.models import ComposerModel


class ResNet18(ComposerModel):

    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet18()  # define PyTorch model in __init__.

    def forward(self, batch):  # batch is the output of the dataloader
        # specify how batches are passed through the model
        inputs, _ = batch
        return self.model(inputs)

    def loss(self, outputs, batch):
        # pass batches and `forward` outputs to the loss
        _, targets = batch
        return F.cross_entropy(outputs, targets)
```
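To see the "subclass and implement forward() and loss()" contract in isolation, here is a minimal stdlib-only sketch. MiniComposerModel, ForwardOnly, Complete and can_instantiate are hypothetical names; the sketch uses only abc.ABC and leaves out torch.nn.Module, so it illustrates the abstract-method requirement rather than reproducing the real class. In this sketch the optional hooks simply raise NotImplementedError.

```python
import abc


class MiniComposerModel(abc.ABC):
    """Hypothetical stand-in for the ComposerModel contract (no torch.nn.Module)."""

    @abc.abstractmethod
    def forward(self, batch):
        """Map a dataloader batch to model outputs."""

    @abc.abstractmethod
    def loss(self, outputs, batch):
        """Map forward() outputs plus the same batch to a loss value."""

    def metrics(self, train=False):
        # Optional hook in this sketch: only needed for logging and validation.
        raise NotImplementedError

    def validate(self, batch):
        # Optional hook in this sketch: only needed for validation.
        raise NotImplementedError


class ForwardOnly(MiniComposerModel):
    # Implements forward() but leaves loss() abstract.
    def forward(self, batch):
        inputs, _ = batch
        return inputs


class Complete(ForwardOnly):
    # Adds loss(), so both abstract methods are now implemented.
    def loss(self, outputs, batch):
        _, targets = batch
        # Sum of squared errors between outputs and targets.
        return sum((o - t) ** 2 for o, t in zip(outputs, targets))


def can_instantiate(cls):
    """Return True if cls can be constructed, i.e. has no abstract methods left."""
    try:
        cls()
        return True
    except TypeError:
        return False
```

Here can_instantiate(ForwardOnly) returns False because loss() is still abstract, while can_instantiate(Complete) returns True.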
- abstract forward(batch)
Compute model output given a batch from the dataloader.
- Parameters
batch (Batch) – The output batch from the dataloader.
- Returns
Tensor | Sequence[Tensor] – The result that is passed to loss() as the parameter outputs.
Warning
This method differs from vanilla PyTorch model.forward(x) or model(x) because it takes a batch of data that must be unpacked.
Example:
```python
def forward(self, batch):  # batch is the output of the dataloader
    inputs, _ = batch
    return self.model(inputs)
```
The outputs of forward() are passed to loss() by the trainer:
```python
for batch in train_dataloader:
    optimizer.zero_grad()
    outputs = model.forward(batch)
    loss = model.loss(outputs, batch)
    loss.backward()
    optimizer.step()  # update parameters using the computed gradients
```
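The data flow in the trainer loop (batch into forward(), forward() outputs plus the same batch into loss()) can be mimicked without PyTorch. The following is a minimal sketch under that assumption: ToyLinearModel and run_epoch are hypothetical names, the model is a plain y = w * x on Python lists, and because there is no autograd, backward() and optimizer.step() appear only as a comment.

```python
class ToyLinearModel:
    """Hypothetical model y = w * x that follows the forward()/loss() contract."""

    def __init__(self, w=2.0):
        self.w = w

    def __call__(self, x):
        # Vanilla-style call: takes inputs only, like model(x).
        return [self.w * xi for xi in x]

    def forward(self, batch):
        # Composer-style forward: receives the whole batch and unpacks it.
        inputs, _ = batch
        return self(inputs)

    def loss(self, outputs, batch):
        # Mean squared error against the targets stored in the same batch.
        _, targets = batch
        return sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(targets)


def run_epoch(model, dataloader):
    """Mimic the trainer's data flow: forward() outputs feed loss()."""
    losses = []
    for batch in dataloader:
        outputs = model.forward(batch)
        losses.append(model.loss(outputs, batch))
        # A real trainer would call loss.backward() and optimizer.step() here.
    return losses
```

For example, run_epoch(ToyLinearModel(), [([1.0, 2.0], [2.0, 4.0]), ([3.0], [5.0])]) returns [0.0, 1.0]: the first batch is fit exactly, and the second predicts 6.0 for a target of 5.0.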
- abstract loss(outputs, batch, *args, **kwargs)
Compute the loss of the model given outputs from forward() and a Batch of data from the dataloader. The Trainer will call .backward() on the returned loss.
- Parameters
outputs (Any) – The output of the forward pass.
batch (Batch) – The output batch from the dataloader.
- Returns
Tensor | Sequence[Tensor] – The loss as a torch.Tensor.
Example:
```python
import torch.nn.functional as F

def loss(self, outputs, batch):
    # pass batches and forward() outputs to the loss
    _, targets = batch  # discard inputs from batch
    return F.cross_entropy(outputs, targets)
```
The outputs of forward() are passed to loss() by the trainer:
```python
for batch in train_dataloader:
    optimizer.zero_grad()
    outputs = model.forward(batch)
    loss = model.loss(outputs, batch)
    loss.backward()
    optimizer.step()  # update parameters using the computed gradients
```
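The loss example delegates to F.cross_entropy, which takes unnormalized logits and integer class targets and, by default, averages the per-sample loss over the batch. As a worked illustration of that formula, here is a stdlib-only sketch; cross_entropy is a hypothetical helper, and it ignores F.cross_entropy's weight, ignore_index and label_smoothing options.

```python
import math


def cross_entropy(logits, targets):
    """Mean cross-entropy over a batch of unnormalized logits.

    Per sample: log(sum_j exp(z_j)) - z_target, i.e. -log softmax(z)[target],
    then averaged over the batch (the 'mean' reduction).
    """
    total = 0.0
    for row, target in zip(logits, targets):
        # Log-sum-exp with the row maximum subtracted for numerical stability.
        m = max(row)
        log_norm = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_norm - row[target]
    return total / len(targets)
```

With equal logits over two classes, every target gets probability 0.5, so cross_entropy([[0.0, 0.0]], [0]) equals log(2), roughly 0.693.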
- metrics(train=False)
Get metrics for evaluating the model. Metrics should be instances of torchmetrics.Metric defined in __init__(). This format enables accurate distributed logging. Metrics consume the outputs of validate(). To track multiple metrics, return them wrapped in a MetricCollection.
- Parameters
train (bool, optional) – True to return metrics that should be computed during training and False otherwise. This flag is set automatically by the Trainer. Default: False.
- Returns
Metric or MetricCollection – An instance of Metric or MetricCollection.
Warning
Each metric keeps internal state that is updated with all data it has seen so far. As a result, use separate metric instances for training and validation. See https://torchmetrics.readthedocs.io/en/latest/pages/overview.html for more details.
Example:
```python
from torchmetrics import MetricCollection
from torchmetrics.classification import Accuracy

from composer.models.loss import CrossEntropyLoss

def __init__(self):
    super().__init__()
    self.train_acc = Accuracy()  # torchmetric
    self.val_acc = Accuracy()
    self.val_loss = CrossEntropyLoss()

def metrics(self, train: bool = False):
    return self.train_acc if train else MetricCollection([self.val_acc, self.val_loss])
```
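The warning about metric state can be seen with a small stateful metric that follows the same update()/compute()/reset() pattern as torchmetrics.Metric. This is a stdlib-only sketch: RunningAccuracy and demo are hypothetical names and do not reproduce torchmetrics internals. Sharing one instance mixes training and validation counts, while separate instances keep them apart.

```python
class RunningAccuracy:
    """Hypothetical stateful metric mimicking the update()/compute() pattern."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear accumulated state.
        self.correct = 0
        self.total = 0

    def update(self, preds, targets):
        # Accumulate counts; state persists across calls until reset().
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self):
        return self.correct / self.total if self.total else 0.0


def demo():
    # One shared instance sees both a "training" and a "validation" batch.
    shared = RunningAccuracy()
    shared.update([1, 1, 1, 1], [1, 1, 1, 1])  # training batch: 4/4 correct
    shared.update([0, 0], [1, 1])              # validation batch: 0/2 correct

    # Separate instances keep the two phases independent.
    train_acc, val_acc = RunningAccuracy(), RunningAccuracy()
    train_acc.update([1, 1, 1, 1], [1, 1, 1, 1])
    val_acc.update([0, 0], [1, 1])
    return shared.compute(), train_acc.compute(), val_acc.compute()
```

demo() reports 4/6 for the shared instance, a blend of both phases, versus 1.0 for training and 0.0 for validation when the instances are separate.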
- validate(batch)
Compute model outputs on the provided data. The trainer calls this method with torch.no_grad enabled. The output of this method is passed directly as input to all metrics returned by metrics().
- Parameters
batch (Batch) – The output batch from the dataloader.
- Returns
Tuple[Any, Any] – A tuple of (outputs, targets) that is passed directly to the update() methods of the metrics returned by metrics().
Example:
```python
def validate(self, batch):  # batch is the output of the dataloader
    inputs, targets = batch
    outputs = self.model(inputs)
    return outputs, targets  # return a tuple of (outputs, targets)
```
This pseudocode illustrates how validate() outputs are passed to metrics():
```python
metrics = model.metrics(train=False)  # get torchmetrics
for batch in val_dataloader:
    outputs, targets = model.validate(batch)
    metrics.update(outputs, targets)  # update metrics with outputs, targets for each batch
metrics.compute()  # compute final metrics
```
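The evaluation pseudocode can be turned into a runnable stdlib-only sketch. ToyClassifier, MatchCounter and evaluate are hypothetical names: the classifier's validate() returns an (outputs, targets) tuple, and the metric exposes update() and compute() in the style of a torchmetrics metric. There is no torch here, so the torch.no_grad context the trainer uses is only noted in a comment.

```python
class ToyClassifier:
    """Hypothetical classifier with a validate() hook returning (outputs, targets)."""

    def validate(self, batch):
        inputs, targets = batch
        # Predict class 1 for positive inputs and class 0 otherwise.
        outputs = [1 if x > 0 else 0 for x in inputs]
        return outputs, targets


class MatchCounter:
    """Hypothetical metric: fraction of outputs that equal their targets."""

    def __init__(self):
        self.matches = 0
        self.seen = 0

    def update(self, outputs, targets):
        self.matches += sum(int(o == t) for o, t in zip(outputs, targets))
        self.seen += len(targets)

    def compute(self):
        return self.matches / self.seen


def evaluate(model, metric, val_dataloader):
    """validate() output feeds metric.update(); compute() runs once at the end."""
    # A real trainer would run this loop under torch.no_grad().
    for batch in val_dataloader:
        outputs, targets = model.validate(batch)
        metric.update(outputs, targets)
    return metric.compute()
```

For example, evaluate(ToyClassifier(), MatchCounter(), [([2.0, -1.0], [1, 0]), ([0.5, -3.0], [0, 0])]) returns 0.75, since three of the four predictions match their targets.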