ComposerClassifier
- class composer.models.ComposerClassifier(module, num_classes=None, train_metrics=None, val_metrics=None, loss_fn=<function soft_cross_entropy>)
A convenience class that creates a ComposerModel for classification tasks from a vanilla PyTorch model. ComposerClassifier requires batches in the form (input, target) and supports a basic classification training loop, computing the loss with loss_fn, which takes the model's outputs and the labels.
- Parameters
module (Module) – A PyTorch neural network module.
num_classes (int, optional) – The number of output classes. Required if self.module does not have a num_classes attribute.
train_metrics (Metric | MetricCollection, optional) – A torchmetric or collection of torchmetrics to be computed on the training set throughout training. (default: MulticlassAccuracy)
val_metrics (Metric | MetricCollection, optional) – A torchmetric or collection of torchmetrics to be computed on the validation set throughout training. (default: composer.metrics.CrossEntropy, MulticlassAccuracy)
loss_fn (Callable, optional) – Loss function to use. It must accept at least two arguments: 1) the output of the model and 2) the target, i.e. the labels from the dataset. (default: soft_cross_entropy)
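Any callable that takes the model's outputs as its first argument and the targets as its second can serve as loss_fn. The sketch below is a hypothetical example: the name `smoothed_cross_entropy` and the smoothing value of 0.1 are illustrative assumptions, not part of Composer. It wraps PyTorch's `torch.nn.functional.cross_entropy` with label smoothing and checks it on random data.

```python
import torch
import torch.nn.functional as F


def smoothed_cross_entropy(outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical loss_fn: cross-entropy with label smoothing.

    outputs: raw logits of shape (batch_size, num_classes)
    target:  integer class indices of shape (batch_size,)
    """
    return F.cross_entropy(outputs, target, label_smoothing=0.1)


# Quick check on random data: 4 samples, 10 classes.
torch.manual_seed(0)
logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
loss = smoothed_cross_entropy(logits, labels)  # scalar tensor (mean over the batch)
print(loss.item())
```

Such a function would then be passed as `ComposerClassifier(pytorch_model, num_classes=10, loss_fn=smoothed_cross_entropy)`.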
- Returns
ComposerClassifier – An instance of ComposerClassifier.
Example:
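```python
import torchvision

from composer.models import ComposerClassifier

# Randomly initialized ResNet-18 (weights=None replaces the deprecated pretrained=False).
pytorch_model = torchvision.models.resnet18(weights=None)
model = ComposerClassifier(pytorch_model, num_classes=1000)
```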
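ComposerClassifier consumes batches of the form (input, target) and applies loss_fn to the module's outputs and the labels. The following plain-PyTorch sketch approximates that contract for repeated optimization steps on one batch. It is an illustration under those assumptions, not Composer's actual implementation (the Composer Trainer also handles devices, metrics, precision, and more). The tiny linear model, the batch shape, the use of `F.cross_entropy` in place of the default soft_cross_entropy, and the SGD settings are all hypothetical choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for the wrapped PyTorch module: 8 features -> 3 classes.
module = nn.Linear(8, 3)
loss_fn = F.cross_entropy  # used here in place of Composer's default soft_cross_entropy
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)

# A batch in the (input, target) form that ComposerClassifier expects.
batch = (torch.randn(16, 8), torch.randint(0, 3, (16,)))


def train_step(batch):
    inputs, targets = batch           # unpack the (input, target) pair
    outputs = module(inputs)          # forward pass sees only the input
    loss = loss_fn(outputs, targets)  # loss takes (outputs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


first = train_step(batch)
for _ in range(20):
    last = train_step(batch)
print(first, last)
```

Keeping the batch as a single (input, target) tuple means the forward pass and the loss can each pick out the part they need, which is why the wrapped module only has to map inputs to outputs.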