ComposerModel
- class composer.ComposerModel
The interface needed to make a PyTorch model compatible with composer.Trainer.
To create a Trainer-compatible model, subclass ComposerModel and implement forward() and loss(). For full functionality (logging and validation), also implement get_metrics() and eval_forward().
See the Composer Model walkthrough for more details.
Minimal Example:
```python
import torchvision
import torch.nn.functional as F

from composer.models import ComposerModel

class ResNet18(ComposerModel):

    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet18()  # define PyTorch model in __init__.

    def forward(self, batch):  # batch is the output of the dataloader
        # specify how batches are passed through the model
        inputs, _ = batch
        return self.model(inputs)

    def loss(self, outputs, batch):
        # pass batches and `forward` outputs to the loss
        _, targets = batch
        return F.cross_entropy(outputs, targets)
```
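Because forward() and loss() are marked abstract, a subclass that leaves either one out should not be instantiable. The sketch below imitates that contract without torch or Composer, assuming the abstract markers are enforced through Python's abc module as the "abstract" labels on this page suggest. ToyComposerModel, ForwardOnly and DoublingModel are hypothetical stand-ins, not Composer classes.

```python
import abc
from typing import Any


class ToyComposerModel(abc.ABC):
    """Hypothetical, dependency-free stand-in for the ComposerModel contract."""

    def __call__(self, batch: Any) -> Any:
        # Mirrors how calling an nn.Module dispatches to forward().
        return self.forward(batch)

    @abc.abstractmethod
    def forward(self, batch: Any) -> Any:
        """Required: map a dataloader batch to model outputs."""

    @abc.abstractmethod
    def loss(self, outputs: Any, batch: Any) -> Any:
        """Required: map outputs and the batch to a loss value."""


class ForwardOnly(ToyComposerModel):
    """Implements forward() but forgets loss(), so it stays abstract."""

    def forward(self, batch):
        inputs, _ = batch
        return inputs


class DoublingModel(ToyComposerModel):
    """Implements both required methods with toy arithmetic."""

    def forward(self, batch):
        inputs, _ = batch
        return [2 * x for x in inputs]  # toy "model": doubles each input

    def loss(self, outputs, batch):
        _, targets = batch
        # Mean squared error between outputs and targets.
        return sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(targets)


try:
    ForwardOnly()
    instantiated = True
except TypeError:
    instantiated = False  # the missing abstract loss() blocks instantiation
```

With both methods in place, `DoublingModel()(([1.0, 2.0], [2.0, 5.0]))` returns `[2.0, 4.0]`, and passing those outputs to `loss()` gives a mean squared error of 0.5.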
- logger
The training Logger. The trainer sets the Logger on the Event.INIT event.
- Type
Optional[Logger]
- eval_forward(batch, outputs=None)
Run the evaluation forward pass.
By default, it returns the outputs if they are not None. Otherwise, self(batch) is returned. Override this method for models that require custom validation logic, e.g., self-supervised learning.
- Parameters
batch: The dataloader batch.
outputs (Any, optional): If training, the outputs from the forward pass. Otherwise, None.
- Returns
Any: The evaluation outputs.
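The default behavior described above comes down to a single conditional. The sketch below restates it in plain Python. It also adds one override in the spirit of the self-supervised case, where evaluation returns features from a hypothetical embed() helper instead of the training outputs. DefaultEval, EmbeddingEval and the (inputs, targets) batch layout are illustrative assumptions, not Composer code.

```python
from typing import Any, List, Optional, Tuple

Batch = Tuple[List[float], List[int]]  # hypothetical (inputs, targets) layout


class DefaultEval:
    """Stand-in that keeps the documented default eval_forward()."""

    def __call__(self, batch: Batch) -> Any:
        return self.forward(batch)

    def forward(self, batch: Batch) -> List[float]:
        inputs, _ = batch
        return [x + 1.0 for x in inputs]  # toy model: shift each input by one

    def eval_forward(self, batch: Batch, outputs: Optional[Any] = None) -> Any:
        # If outputs from the training forward pass are supplied, reuse them
        # instead of paying for a second forward pass.
        if outputs is not None:
            return outputs
        return self(batch)


class EmbeddingEval(DefaultEval):
    """Override for a model whose evaluation output differs from training."""

    def embed(self, inputs: List[float]) -> List[List[float]]:
        # Hypothetical feature extractor: pairs each input with its square.
        return [[x, x * x] for x in inputs]

    def eval_forward(self, batch: Batch, outputs: Optional[Any] = None) -> Any:
        # Ignore any training outputs and always evaluate on embeddings.
        inputs, _ = batch
        return self.embed(inputs)
```

The override deliberately ignores `outputs`. That is the usual reason to replace the default: when the evaluation signal is not the same thing the training forward pass produces.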
- abstract forward(batch)
Compute model output given a batch from the dataloader.
- Parameters
batch (Batch): The output batch from the dataloader.
- Returns
Any: The result that is passed to loss() as the parameter outputs.
Warning
This method is different from vanilla PyTorch model.forward(x) or model(x), as it takes a batch of data that has to be unpacked.
Example:
```python
def forward(self, batch):
    # batch is the output of the dataloader
    inputs, _ = batch
    return self.model(inputs)
```
The outputs of forward() are passed to loss() by the trainer:
```python
for batch in train_dataloader:
    optimizer.zero_grad()
    outputs = model.forward(batch)
    loss = model.loss(outputs, batch)
    loss.backward()
    optimizer.step()
```
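forward() receives whatever the dataloader yields, so the unpacking step has to match that structure. The dependency-free sketch below assumes a hypothetical dataloader that yields dictionaries with "inputs" and "targets" keys. The wrapped vanilla_model plays the role of a plain model(x) and only ever sees the inputs. DictBatchModel and the key names are illustrative, not part of Composer.

```python
from typing import Callable, Dict, List

Batch = Dict[str, List[float]]  # hypothetical batch layout from the dataloader


def vanilla_model(x: List[float]) -> List[float]:
    """Plays the role of a plain model(x): it receives inputs only."""
    return [3.0 * v for v in x]


class DictBatchModel:
    """Hypothetical ComposerModel-style wrapper around vanilla_model."""

    def __init__(self, model: Callable[[List[float]], List[float]]):
        self.model = model

    def forward(self, batch: Batch) -> List[float]:
        # The trainer hands over the whole batch; pick out the inputs here.
        return self.model(batch["inputs"])

    def loss(self, outputs: List[float], batch: Batch) -> float:
        # Targets come from the same batch; mean absolute error as a toy loss.
        targets = batch["targets"]
        return sum(abs(o - t) for o, t in zip(outputs, targets)) / len(targets)


model = DictBatchModel(vanilla_model)
batch = {"inputs": [1.0, 2.0], "targets": [3.0, 5.0]}
outputs = model.forward(batch)  # [3.0, 6.0]
loss_value = model.loss(outputs, batch)
```

Only forward() and loss() know the batch layout. The training loop stays the same whichever structure the dataloader yields.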
- get_metrics(is_train)
Get the metrics.
This method will be called by the trainer immediately after Event.INIT.
Note
Each item in the returned dictionary is copied with copy.deepcopy before it is used. This ensures that each dataloader (e.g., train, eval) accumulates metrics separately. To share a metric across all dataloaders, wrap it with MetricSpec(metric=metric, share=False).
- Parameters
is_train (bool): Whether the training metrics or evaluation metrics should be returned.
- Returns
dict[str, Metric]: A mapping of the metric name to a Metric.
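The note about copy.deepcopy can be illustrated without torchmetrics. In the sketch below, RunningMean is a hypothetical toy metric and get_metrics is a free function rather than a method. The dictionary comprehension imitates the per-dataloader copying the note describes: it makes one deep copy per evaluation dataloader, so the accumulated state of one copy never leaks into another or into the original.

```python
import copy
from typing import Dict


class RunningMean:
    """Hypothetical toy metric that accumulates a running mean."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def compute(self) -> float:
        return self.total / self.count if self.count else 0.0


def get_metrics(is_train: bool) -> Dict[str, RunningMean]:
    # Training and evaluation may expose different metric sets.
    if is_train:
        return {"train_mean": RunningMean()}
    return {"eval_mean": RunningMean()}


# One deep copy of the evaluation metrics per dataloader.
template = get_metrics(is_train=False)
per_loader = {name: copy.deepcopy(template) for name in ("eval_a", "eval_b")}

per_loader["eval_a"]["eval_mean"].update(1.0)
per_loader["eval_a"]["eval_mean"].update(3.0)
per_loader["eval_b"]["eval_mean"].update(10.0)
```

After these updates, eval_a reports 2.0 and eval_b reports 10.0, while the template returned by get_metrics still has a count of zero.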
- abstract loss(outputs, batch, *args, **kwargs)
Compute the loss of the model given outputs from forward() and a Batch of data from the dataloader. The Trainer will call .backward() on the returned loss.
- Parameters
outputs (Any): The output of the forward pass.
batch (Batch): The output batch from the dataloader.
- Returns
Tensor | Sequence[Tensor]: The loss as a torch.Tensor, or a sequence of torch.Tensor values.
Example:
```python
import torch.nn.functional as F

def loss(self, outputs, batch):
    # pass batches and forward() outputs to the loss
    _, targets = batch  # discard inputs from batch
    return F.cross_entropy(outputs, targets)
```
The outputs of forward() are passed to loss() by the trainer:
```python
for batch in train_dataloader:
    optimizer.zero_grad()
    outputs = model.forward(batch)
    loss = model.loss(outputs, batch)
    loss.backward()
    optimizer.step()
```
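The example loss delegates to F.cross_entropy. That function takes unnormalized logits and integer class targets and, with its default reduction, averages over the batch. The sketch below spells out the same quantity for plain Python lists, using the log-sum-exp trick for numerical stability. Here cross_entropy is a hypothetical helper for illustration, not PyTorch's implementation.

```python
import math
from typing import List, Sequence


def cross_entropy(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean cross-entropy of raw logits against integer class targets."""
    total = 0.0
    for row, target in zip(logits, targets):
        # log-sum-exp with the row max subtracted to avoid overflow in exp().
        m = max(row)
        log_norm = m + math.log(sum(math.exp(z - m) for z in row))
        # Negative log-probability of the correct class.
        total += log_norm - row[target]
    return total / len(targets)


def loss(outputs: List[List[float]], batch) -> float:
    # Same shape as the example above: discard inputs, keep targets.
    _, targets = batch
    return cross_entropy(outputs, targets)
```

With two equal logits, each class has probability 1/2, so the per-sample loss is log 2. A very confident correct logit drives the loss toward zero without overflowing.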
- update_metric(batch, outputs, metric)
Update the given metric.
- Parameters
batch: The dataloader batch.
outputs: The output from eval_forward().
metric (Metric): The metric to update.
- Returns
Optional[dict]: Optionally, metric results to be stored in state.
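A typical override unpacks the targets from the batch and passes them, together with the eval_forward() outputs, to the metric. The sketch below assumes an (inputs, targets) batch and uses a hypothetical Accuracy class whose update(preds, targets) method is loosely modeled on the torchmetrics convention. ToyClassifier is likewise a stand-in that shows only this method. Returning None means nothing extra is stored in state.

```python
from typing import Any, List, Optional, Sequence


class Accuracy:
    """Hypothetical toy metric: fraction of predictions matching targets."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: Sequence[int], targets: Sequence[int]) -> None:
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0


def argmax(row: Sequence[float]) -> int:
    return max(range(len(row)), key=lambda i: row[i])


class ToyClassifier:
    """Stand-in showing only the update_metric() part of the interface."""

    def update_metric(
        self, batch: Any, outputs: List[List[float]], metric: Accuracy
    ) -> Optional[dict]:
        _, targets = batch                        # discard inputs, keep labels
        preds = [argmax(row) for row in outputs]  # logits -> predicted class
        metric.update(preds, targets)
        return None                               # nothing extra to store in state
```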