ComposerModel
- class composer.ComposerModel
- The interface needed to make a PyTorch model compatible with `composer.Trainer`. To create a `Trainer`-compatible model, subclass `ComposerModel` and implement `forward()` and `loss()`. For full functionality (logging and validation), also implement `get_metrics()` and `eval_forward()`. See the Composer Model walkthrough for more details.

Minimal Example:

```python
import torchvision
import torch.nn.functional as F

from composer.models import ComposerModel


class ResNet18(ComposerModel):

    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet18()  # define PyTorch model in __init__.

    def forward(self, batch):  # batch is the output of the dataloader
        # specify how batches are passed through the model
        inputs, _ = batch
        return self.model(inputs)

    def loss(self, outputs, batch):
        # pass batches and `forward` outputs to the loss
        _, targets = batch
        return F.cross_entropy(outputs, targets)
```

- logger
- The training `Logger`. The trainer sets the `Logger` on the `Event.INIT` event.
- Type
- Optional[Logger] 
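Because `logger` is typed `Optional[Logger]` and is only set on `Event.INIT`, model code that logs should tolerate `None`. The framework-free sketch below illustrates that guard. `RecordingLogger`, `ToyModel`, `log_grad_norm`, and the `"model/grad_norm"` key are all hypothetical, and the `log_metrics(dict)` method name is an assumption modeled on Composer's `Logger`, so check it against your installed version.

```python
from typing import Any, Dict, List, Optional


class RecordingLogger:
    """Hypothetical stand-in for composer.loggers.Logger; it only records calls."""

    def __init__(self) -> None:
        self.records: List[Dict[str, Any]] = []

    def log_metrics(self, metrics: Dict[str, Any]) -> None:
        # Assumed method name, modeled on Composer's Logger.
        self.records.append(dict(metrics))


class ToyModel:
    """Hypothetical class, not a ComposerModel; mimics only the `logger` attribute."""

    def __init__(self) -> None:
        # None until the trainer assigns a logger on Event.INIT.
        self.logger: Optional[RecordingLogger] = None

    def log_grad_norm(self, value: float) -> bool:
        # Guard against logging before the trainer has set the logger.
        if self.logger is None:
            return False
        self.logger.log_metrics({"model/grad_norm": value})
        return True


model = ToyModel()
before = model.log_grad_norm(1.5)   # skipped: no logger yet
model.logger = RecordingLogger()    # stands in for what the trainer does on Event.INIT
after = model.log_grad_norm(0.5)
```

Returning a flag from the guard is only for illustration; a real model would typically just skip logging silently.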
 
  - eval_forward(batch, outputs=None)
- Run the evaluation forward pass.

By default, it returns `outputs` if they are not `None`; otherwise, it returns `self(batch)`. Override this method for models that require custom validation logic, such as self-supervised learning.

- Parameters
- batch – The dataloader batch.
- outputs (Any, optional) – If training, the outputs from the forward pass. Otherwise, `None`.
 
- Returns
- Any – The evaluation outputs.
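The default behavior described for `eval_forward` (reuse `outputs` when they are provided, otherwise call `self(batch)`) can be mirrored in a few lines. This is a framework-free sketch: `ToyModel` is a hypothetical stand-in for a `ComposerModel`, and its `calls` counter exists only to show that passing `outputs` avoids a second forward pass.

```python
from typing import Any, Callable, Optional


class ToyModel:
    """Hypothetical stand-in that mirrors the documented eval_forward default."""

    def __init__(self, fn: Callable[[Any], Any]) -> None:
        self.fn = fn
        self.calls = 0  # counts how often the forward pass actually runs

    def __call__(self, batch: Any) -> Any:
        self.calls += 1
        return self.fn(batch)

    def eval_forward(self, batch: Any, outputs: Optional[Any] = None) -> Any:
        # Reuse the training outputs when given; otherwise run the model.
        return outputs if outputs is not None else self(batch)


# The stand-in "model" doubles the inputs of an (inputs, targets) batch.
model = ToyModel(lambda batch: [2 * x for x in batch[0]])
batch = ([1, 2, 3], [0, 1, 0])
fresh = model.eval_forward(batch)                   # runs the model
reused = model.eval_forward(batch, outputs=fresh)   # no second forward pass
```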
 
  - abstract forward(batch)
- Compute the model output given a batch from the dataloader.
- Parameters
- batch (Batch) – The output batch from the dataloader.
- Returns
- Tensor | Sequence[Tensor] – The result that is passed to `loss()` as the parameter `outputs`.
 - Warning

This method is different from vanilla PyTorch `model.forward(x)` or `model(x)` because it takes a batch of data that has to be unpacked.

Example:

```python
def forward(self, batch):  # batch is the output of the dataloader
    inputs, _ = batch
    return self.model(inputs)
```

The outputs of `forward()` are passed to `loss()` by the trainer:

```python
for batch in train_dataloader:
    optimizer.zero_grad()
    outputs = model.forward(batch)
    loss = model.loss(outputs, batch)
    loss.backward()
    optimizer.step()
```
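To make the Warning concrete, the sketch below contrasts a vanilla callable that takes inputs only with a `forward` that takes a whole batch and unpacks it. It is framework-free: `double` and `ToyComposerStyleModel` are hypothetical, and the dict-batch branch with its `"inputs"`/`"targets"` keys is an assumed alternative batch format, not something this page specifies.

```python
from collections.abc import Mapping
from typing import Any, Callable, List, Sequence


def double(xs: Sequence[float]) -> List[float]:
    """Hypothetical vanilla model: takes inputs only, like model(x)."""
    return [2 * x for x in xs]


class ToyComposerStyleModel:
    """Hypothetical class, not a real ComposerModel; shows only batch unpacking."""

    def __init__(self, model: Callable[[Sequence[float]], List[float]]) -> None:
        self.model = model

    def forward(self, batch: Any) -> List[float]:
        # Tuple batches are (inputs, targets); dict batches use assumed keys.
        if isinstance(batch, Mapping):
            inputs = batch["inputs"]
        else:
            inputs, _ = batch
        return self.model(inputs)


wrapped = ToyComposerStyleModel(double)
from_tuple = wrapped.forward(([1.0, 2.0], [0, 1]))
from_dict = wrapped.forward({"inputs": [1.0, 2.0], "targets": [0, 1]})
```

Either way, the targets stay in the batch so that `loss()` can retrieve them later.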
  - get_metrics(is_train)
- Get the metrics.

This method will be called by the trainer immediately after `Event.INIT`.

Note

Each item in the returned dictionary will be deep-copied with `copy.deepcopy` before it is used. This ensures that each dataloader (e.g. train, eval) accumulates metrics separately.

To share a metric across all dataloaders, wrap it with `MetricSpec(metric=metric, share=True)`.

- Parameters
- is_train (bool) – Whether the training metrics or evaluation metrics should be returned.
- Returns
- Dict[str, Metric] – A mapping of the metric name to a `Metric`.
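The deep-copy note can be illustrated without torchmetrics. In this sketch, `RunningAccuracy` is a hypothetical stateful metric with an `update`/`compute` pair, `get_metrics` is a free function standing in for the method, and keying the copies by dataloader name (`"eval_a"`, `"eval_b"`) is an assumption about how a trainer might organize them. The point is only that each deep copy accumulates its own state while the returned prototype stays untouched.

```python
import copy
from typing import Dict, List


class RunningAccuracy:
    """Hypothetical minimal stateful metric (not torchmetrics)."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: List[int], targets: List[int]) -> None:
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0


def get_metrics(is_train: bool) -> Dict[str, RunningAccuracy]:
    # A real model may return different dicts for train and eval; this one does not.
    return {"accuracy": RunningAccuracy()}


prototype = get_metrics(is_train=False)
# Assumed trainer behavior: each dataloader gets its own deep copy.
per_loader = {name: copy.deepcopy(prototype) for name in ("eval_a", "eval_b")}
per_loader["eval_a"]["accuracy"].update([1, 1, 0, 0], [1, 1, 1, 1])
per_loader["eval_b"]["accuracy"].update([1, 0], [1, 0])
```

After these updates, `eval_a` reports 0.5, `eval_b` reports 1.0, and the prototype has seen no data.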
 
  - abstract loss(outputs, batch, *args, **kwargs)
- Compute the loss of the model given `outputs` from `forward()` and a `Batch` of data from the dataloader. The `Trainer` will call `.backward()` on the returned loss.
- Parameters
- outputs (Any) – The output of the forward pass.
- batch (Batch) – The output batch from the dataloader.
 
- Returns
- Tensor | Sequence[Tensor] – The loss as a `torch.Tensor`.
 - Example:

```python
import torch.nn.functional as F

def loss(self, outputs, batch):
    # pass batches and `forward` outputs to the loss
    _, targets = batch  # discard inputs from batch
    return F.cross_entropy(outputs, targets)
```

The outputs of `forward()` are passed to `loss()` by the trainer:

```python
for batch in train_dataloader:
    optimizer.zero_grad()
    outputs = model.forward(batch)
    loss = model.loss(outputs, batch)
    loss.backward()
    optimizer.step()
```
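To show what the example `loss` computes, here is a pure-Python sketch of mean cross-entropy over logits: the log-sum-exp of each row minus the logit of the target class, averaged over the batch. It matches `F.cross_entropy` only under its default settings (no class weights, mean reduction), and it returns a plain `float`, so unlike a `torch.Tensor` the trainer could not call `.backward()` on it. The helper `cross_entropy` and the toy data are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def cross_entropy(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Hypothetical helper: mean negative log-softmax probability of each target class."""
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[target]
    return total / len(targets)


def loss(outputs: List[List[float]], batch: Tuple[list, List[int]]) -> float:
    _, targets = batch  # discard inputs, keep targets
    return cross_entropy(outputs, targets)


outputs = [[0.0, 0.0], [2.0, 0.0]]  # logits for two samples, two classes
batch = ([None, None], [0, 0])      # inputs are unused by the loss
value = loss(outputs, batch)
```

The first sample contributes ln 2 (uniform logits), and the second contributes ln(1 + e^-2), so `value` is their average, about 0.41.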
  - metrics(train=False)
- Get metrics for evaluating the model. Metrics should be instances of `torchmetrics.Metric` defined in `__init__()`. This format enables accurate distributed logging. Metrics consume the outputs of `validate()`. To track multiple metrics, return them wrapped in a `MetricCollection`.
- Parameters
- train (bool, optional) – `True` to return metrics that should be computed during training, `False` otherwise. This flag is set automatically by the `Trainer`. Default: `False`.
- Returns
- Metric or MetricCollection – An instance of `Metric` or `MetricCollection`.
 - Warning

Each metric keeps state that is updated with the data seen so far. As a result, separate metric instances should be used for training and validation. See https://torchmetrics.readthedocs.io/en/latest/pages/overview.html for more details.

Example:

```python
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy
from composer.models.loss import CrossEntropyLoss

def __init__(self, num_classes):
    super().__init__()
    self.train_acc = MulticlassAccuracy(num_classes=num_classes, average='micro')  # torchmetric
    self.val_acc = MulticlassAccuracy(num_classes=num_classes, average='micro')
    self.val_loss = CrossEntropyLoss()

def metrics(self, train: bool = False):
    return self.train_acc if train else MetricCollection([self.val_acc, self.val_loss])
```
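The Warning about separate instances is easy to demonstrate with any stateful metric. This sketch uses a hypothetical running-mean metric, `MeanValue`, instead of torchmetrics: a single shared instance mixes training and validation data, while separate instances (selected by a hypothetical free-function `metrics(train)`, mirroring the method above) keep validation results clean.

```python
from typing import List


class MeanValue:
    """Hypothetical stateful metric: running mean of every value seen."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, values: List[float]) -> None:
        self.total += sum(values)
        self.count += len(values)

    def compute(self) -> float:
        return self.total / self.count


# One shared instance: validation results are polluted by training data.
shared = MeanValue()
shared.update([1.0, 1.0, 1.0, 1.0])  # training batches
shared.update([0.0, 0.0])            # validation batches
polluted = shared.compute()          # 4 / 6, not the validation mean

# Separate instances, selected by the `train` flag.
train_metric, val_metric = MeanValue(), MeanValue()


def metrics(train: bool = False) -> MeanValue:
    return train_metric if train else val_metric


metrics(train=True).update([1.0, 1.0, 1.0, 1.0])
metrics().update([0.0, 0.0])
clean = metrics().compute()
```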
  - update_metric(batch, outputs, metric)
- Update the given metric.
- Parameters
- batch – The dataloader batch.
- outputs – The output from `eval_forward()`.
- metric (Metric) – The metric to update.
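`update_metric` is where the outputs of `eval_forward()` meet the batch targets. The sketch below walks through a small evaluation loop under stated assumptions: `ToyClassifier` (predicts 1 for non-negative inputs) and `RunningAccuracy` are hypothetical framework-free stand-ins, and the metric is assumed to expose a torchmetrics-like `update(preds, targets)` / `compute()` pair. The sketch does not show the real trainer's evaluation loop.

```python
from typing import Any, List, Optional, Tuple

Batch = Tuple[List[int], List[int]]  # (inputs, targets)


class RunningAccuracy:
    """Hypothetical stateful metric with a torchmetrics-like update/compute pair."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: List[int], targets: List[int]) -> None:
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self) -> float:
        return self.correct / self.total


class ToyClassifier:
    """Hypothetical stand-in: predicts 1 for non-negative inputs, else 0."""

    def __call__(self, batch: Batch) -> List[int]:
        inputs, _ = batch
        return [int(x >= 0) for x in inputs]

    def eval_forward(self, batch: Batch, outputs: Optional[Any] = None) -> List[int]:
        return outputs if outputs is not None else self(batch)

    def update_metric(self, batch: Batch, outputs: List[int], metric: RunningAccuracy) -> None:
        _, targets = batch  # compare eval outputs against the batch targets
        metric.update(outputs, targets)


model, metric = ToyClassifier(), RunningAccuracy()
eval_batches = [([3, -1], [1, 0]), ([-2, 5], [1, 1])]
for batch in eval_batches:
    outputs = model.eval_forward(batch)
    model.update_metric(batch, outputs, metric)
accuracy = metric.compute()
```

The first batch is fully correct and the second gets one of two right, so `accuracy` ends at 3 / 4.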