ComposerClassifier
- class composer.models.ComposerClassifier(module, train_metrics=None, val_metrics=None, loss_fn=<function soft_cross_entropy>)
A convenience class that creates a ComposerModel for classification tasks from a vanilla PyTorch model. ComposerClassifier requires batches in the form (input, target) and includes a basic classification training loop with a loss function, loss_fn, which takes in the model's outputs and the labels.
- Parameters
module (Module) – A PyTorch neural network module.
train_metrics (Metric | MetricCollection, optional) – A torchmetric or collection of torchmetrics to be computed on the training set throughout training.
val_metrics (Metric | MetricCollection, optional) – A torchmetric or collection of torchmetrics to be computed on the validation set throughout training.
loss_fn (Callable, optional) – Loss function to use. Defaults to soft_cross_entropy. This loss function should accept at least two arguments: 1) the output of the model and 2) target, i.e., the labels from the dataset.
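The batch and loss contract described above can be sketched without any framework. In this minimal sketch, plain Python lists stand in for tensors, and compute_loss, toy_soft_cross_entropy, log_softmax, and toy_module are hypothetical helpers written for illustration only; they are not part of Composer's API and do not reproduce its implementation. toy_soft_cross_entropy loosely mirrors the idea behind the default soft_cross_entropy loss by accepting either a class index or a per-class probability vector as each target; the real loss operates on PyTorch tensors and its exact behavior may differ.

```python
import math
from typing import Callable, List, Sequence, Tuple, Union

# A target is either a class index (hard label) or a per-class probability vector (soft label).
Target = Union[int, Sequence[float]]


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the largest logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


def toy_soft_cross_entropy(outputs: Sequence[Sequence[float]],
                           targets: Sequence[Target]) -> float:
    """Hypothetical loss: mean cross-entropy over a batch of logits."""
    total = 0.0
    for logits, target in zip(outputs, targets):
        log_probs = log_softmax(logits)
        if isinstance(target, int):
            # Hard label: negative log-probability of the true class.
            total -= log_probs[target]
        else:
            # Soft label: cross-entropy between the target distribution and the prediction.
            total -= sum(p * lp for p, lp in zip(target, log_probs))
    return total / len(outputs)


def compute_loss(module: Callable, loss_fn: Callable, batch: Tuple) -> float:
    """Hypothetical step: unpack an (input, target) batch and score the module's outputs."""
    inputs, targets = batch
    outputs = module(inputs)
    # Model outputs go first, labels second, matching the documented loss_fn contract.
    return loss_fn(outputs, targets)


if __name__ == "__main__":
    def toy_module(inputs):
        # Stand-in for a PyTorch module: the same three-class logits for every input.
        return [[2.0, 0.0, 0.0] for _ in inputs]

    batch = (["img0", "img1"], [0, 2])
    print(compute_loss(toy_module, toy_soft_cross_entropy, batch))
```

Any callable that takes the model outputs first and the labels second can fill the loss_fn role in this sketch; with ComposerClassifier itself, both arguments would be PyTorch tensors.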
- Returns
ComposerClassifier – An instance of ComposerClassifier.
Example:
import torchvision
from composer.models import ComposerClassifier

# Randomly initialized ResNet-18, wrapped with the default loss (soft_cross_entropy).
pytorch_model = torchvision.models.resnet18(pretrained=False)
model = ComposerClassifier(pytorch_model)