ComposerClassifier

class composer.models.ComposerClassifier(module, train_metrics=None, val_metrics=None, loss_fn=<function soft_cross_entropy>)

A convenience class that creates a ComposerModel for classification tasks from a vanilla PyTorch model. ComposerClassifier requires batches of the form (input, target) and provides a basic classification training loop whose loss function, loss_fn, takes the model's outputs and the labels.
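The (input, target) contract and the loss step described above can be pictured with a minimal plain-PyTorch sketch. It is an illustration under stated assumptions, not ComposerClassifier's actual code: `classification_step` is a hypothetical helper, the toy `nn.Linear` model is invented for the example, and `F.cross_entropy` stands in for the default `soft_cross_entropy`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def classification_step(module: nn.Module, batch, loss_fn):
    """Hypothetical helper: one forward + loss step over an (input, target) batch."""
    inputs, targets = batch           # the batch must unpack into exactly two items
    outputs = module(inputs)          # raw logits from the wrapped PyTorch model
    loss = loss_fn(outputs, targets)  # loss_fn(outputs, target), as documented
    return outputs, loss


# Toy setup (assumption): a linear classifier over 4 features and 3 classes.
torch.manual_seed(0)
module = nn.Linear(4, 3)
batch = (torch.randn(8, 4), torch.randint(0, 3, (8,)))

# F.cross_entropy is a stand-in for the default soft_cross_entropy.
outputs, loss = classification_step(module, batch, F.cross_entropy)
loss.backward()  # gradients flow back into the module's parameters
```

Because the loss only ever sees `(outputs, targets)`, any module whose forward pass maps a batch of inputs to per-class logits fits this pattern.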

Parameters
  • module (Module) – A PyTorch neural network module.

  • train_metrics (Metric | MetricCollection, optional) – A torchmetric or collection of torchmetrics to be computed on the training set throughout training.

  • val_metrics (Metric | MetricCollection, optional) – A torchmetric or collection of torchmetrics to be computed on the validation set throughout training.

  • loss_fn (Callable, optional) – Loss function to use. It must accept at least two arguments: (1) the model's output and (2) the target, i.e. the labels from the dataset. Defaults to soft_cross_entropy.
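The default loss_fn, soft_cross_entropy, appears from its name to be a cross-entropy that also accepts soft (probability) targets. The sketch below assumes that behaviour and shows a custom loss with the required (outputs, target) signature: for logits z and a target distribution y over C classes it computes -Σ_c y_c · log softmax(z)_c, averaged over the batch, and falls back to ordinary cross-entropy for integer class labels. `soft_ce_sketch` is a hypothetical name, not part of Composer.

```python
import torch
import torch.nn.functional as F


def soft_ce_sketch(outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical loss_fn: cross-entropy accepting index or probability targets."""
    if target.dtype in (torch.int64, torch.int32):
        # Hard labels: integer class indices of shape (N,)
        return F.cross_entropy(outputs, target)
    # Soft labels: per-class probabilities of shape (N, C)
    log_probs = F.log_softmax(outputs, dim=1)
    return -(target * log_probs).sum(dim=1).mean()
```

A function like this could be supplied as ComposerClassifier(pytorch_model, loss_fn=soft_ce_sketch). With one-hot targets it returns the same value as with the equivalent integer labels, because only the true class contributes to the sum.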

Returns

ComposerClassifier – An instance of ComposerClassifier.

Example:

import torchvision
from composer.models import ComposerClassifier

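# weights=None builds an untrained network (torchvision 0.13+ replacement for pretrained=False)
pytorch_model = torchvision.models.resnet18(weights=None)
# Wrap the plain PyTorch module so Composer can train and evaluate it
model = ComposerClassifier(pytorch_model)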
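ComposerClassifier expects each batch to unpack into (input, target). A minimal sketch of a dataloader that satisfies this for the resnet18 example, assuming synthetic random tensors in place of a real dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in data (assumption) shaped for resnet18: 16 RGB images,
# 224x224, with integer labels in [0, 1000) (torchvision's default num_classes).
images = torch.randn(16, 3, 224, 224)
labels = torch.randint(0, 1000, (16,))

# TensorDataset yields (image, label) pairs; the default collate function
# stacks them into a two-element batch that unpacks as (inputs, targets).
loader = DataLoader(TensorDataset(images, labels), batch_size=4)

inputs, targets = next(iter(loader))
```

Such a loader could then be handed to Composer's Trainer as its train_dataloader, alongside model. Datasets from torchvision.datasets also return (image, target) pairs, so they fit the same contract.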