BaseMosaicModel

class composer.models.BaseMosaicModel

Bases: torch.nn.modules.module.Module, abc.ABC

The minimal interface needed to use a model with composer.trainer.Trainer.
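Because the class combines torch.nn.Module with abc.ABC, a subclass cannot be instantiated until it overrides all four abstract methods documented here. The sketch shows this with a hypothetical local stand-in, MosaicModelInterface, that mirrors the documented signatures so it runs without Composer installed; it is an illustration of the mechanism, not the library class itself.

```python
import abc
from typing import Any, Tuple

import torch


class MosaicModelInterface(torch.nn.Module, abc.ABC):
    """Hypothetical stand-in mirroring BaseMosaicModel's abstract methods."""

    @abc.abstractmethod
    def forward(self, batch: Any) -> Any:
        ...

    @abc.abstractmethod
    def loss(self, outputs: Any, batch: Any, *args, **kwargs) -> Any:
        ...

    @abc.abstractmethod
    def metrics(self, train: bool = False) -> Any:
        ...

    @abc.abstractmethod
    def validate(self, batch: Any) -> Tuple[Any, Any]:
        ...


class IncompleteModel(MosaicModelInterface):
    # Implements forward() only; loss(), metrics() and validate() remain abstract.
    def forward(self, batch):
        inputs, _ = batch
        return inputs


print(sorted(IncompleteModel.__abstractmethods__))  # ['loss', 'metrics', 'validate']

try:
    IncompleteModel()
except TypeError as err:
    # Python refuses to instantiate a class that still has abstract methods.
    print("TypeError:", err)
```

Overriding loss(), metrics() and validate() as well would make the subclass concrete and usable.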

abstract forward(batch: composer.core.types.Batch) → composer.core.types.Tensors

Compute the model output for an input batch.

Parameters

batch (Batch) – The input batch for the forward pass.

Returns

Tensors – The model output as a Tensors object; it is passed to loss() as the outputs argument.
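As a minimal sketch of forward(), the hypothetical TinyClassifier assumes each batch is an (inputs, targets) pair, the layout validate() documents, and returns raw logits. The model, sizes, and batch are illustrative only, and it subclasses plain torch.nn.Module rather than BaseMosaicModel so that it runs on its own.

```python
import torch


class TinyClassifier(torch.nn.Module):
    """Hypothetical example model; only forward() is shown."""

    def __init__(self, in_features: int = 4, num_classes: int = 3):
        super().__init__()
        self.fc = torch.nn.Linear(in_features, num_classes)

    def forward(self, batch):
        # Assumes batch = (inputs, targets); forward() only needs the inputs.
        inputs, _ = batch
        return self.fc(inputs)  # raw logits, later handed to loss()


torch.manual_seed(0)
model = TinyClassifier()
batch = (torch.randn(8, 4), torch.randint(0, 3, (8,)))
logits = model(batch)  # nn.Module.__call__ dispatches to forward(batch)
print(logits.shape)  # torch.Size([8, 3])
```

Taking the whole batch, rather than only the inputs, lets the trainer call every model the same way regardless of how its batches are structured.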

abstract loss(outputs: Any, batch: composer.core.types.Batch, *args, **kwargs) → composer.core.types.Tensors

Compute the loss of the model.

Parameters
  • outputs (Any) – The output of the forward pass.

  • batch (Batch) – The input batch from the dataloader.

Returns

Tensors – The loss as a Tensors object.
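A matching loss() sketch, assuming the same (inputs, targets) batch layout and a classification task, could apply cross-entropy between the logits from forward() and the targets. Cross-entropy is an illustrative choice, not something the interface prescribes, and the random outputs tensor merely stands in for the result of forward().

```python
import torch
import torch.nn.functional as F


def loss(outputs: torch.Tensor, batch, *args, **kwargs) -> torch.Tensor:
    # Assumes batch = (inputs, targets); only the targets are needed here.
    _, targets = batch
    return F.cross_entropy(outputs, targets)


torch.manual_seed(0)
batch = (torch.randn(8, 4), torch.randint(0, 3, (8,)))

# Stand-in for forward(batch): logits for 8 samples and 3 classes.
outputs = torch.randn(8, 3, requires_grad=True)
value = loss(outputs, batch)  # scalar tensor
value.backward()  # gradients flow back to the outputs tensor

# Sanity check: uniform (all-zero) logits give a loss of ln(3) for 3 classes.
print(round(loss(torch.zeros(8, 3), batch).item(), 4))  # 1.0986
```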

abstract metrics(train: bool = False) → Union[Metric, torchmetrics.collections.MetricCollection]

Get metrics for evaluating the model.

Warning

Each metric keeps internal state that is updated with all of the data seen so far. As a result, separate metric instances should be used for training and validation. See https://torchmetrics.readthedocs.io/en/latest/pages/overview.html for more details.

Parameters

train (bool, optional) – True to return metrics that should be computed during training and False otherwise. (default: False)

Returns

Metrics – A torchmetrics Metric or MetricCollection.
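One way to follow the warning about metric state is to create separate metric objects for training and validation and let the train flag select between them. The sketch uses a hypothetical RunningAccuracy class in place of a torchmetrics Metric (it mimics only the accumulate-on-update(), read-on-compute() behavior) and a plain Python class in place of a BaseMosaicModel subclass, so it runs without either library.

```python
class RunningAccuracy:
    """Hypothetical stand-in for a stateful metric: update() accumulates, compute() reads."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, targets):
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self):
        return self.correct / self.total


class ModelMetricsSketch:
    """Hypothetical model fragment showing only metrics()."""

    def __init__(self):
        # Separate instances so training data never leaks into validation state.
        self.train_acc = RunningAccuracy()
        self.val_acc = RunningAccuracy()

    def metrics(self, train: bool = False):
        return self.train_acc if train else self.val_acc


train_batch = ([1, 1, 0, 0], [1, 1, 0, 0])  # 4 of 4 correct
val_batch = ([1, 1, 1, 1], [0, 0, 0, 0])    # 0 of 4 correct

# Sharing one instance mixes both splits into a single running total.
shared = RunningAccuracy()
shared.update(*train_batch)
shared.update(*val_batch)
print(shared.compute())  # 0.5

model = ModelMetricsSketch()
model.metrics(train=True).update(*train_batch)
model.metrics(train=False).update(*val_batch)
print(model.metrics(train=True).compute(), model.metrics().compute())  # 1.0 0.0
```

The shared instance reports 0.5, which matches neither split, while the separate instances report each split's accuracy correctly.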

abstract validate(batch: composer.core.types.Batch) → Tuple[Any, Any]

Compute model outputs on the provided data.

The output of this function will be directly used as input to all metrics returned by metrics().

Parameters

batch (Batch) – The data to perform validation with. Specified as a tuple of tensors (input, target).

Returns

Tuple[Any, Any] – Tuple that is passed directly to the update() methods of the metrics returned by metrics(). Most often, this will be a tuple of the form (predictions, targets).
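A minimal sketch of validate() returning (predictions, targets), and of that tuple being unpacked straight into a metric's update(), might look like the following. TinyClassifier and CorrectCount are hypothetical; the identity weights are only there to make the predictions easy to check, and CorrectCount stands in for a torchmetrics metric.

```python
import torch


class TinyClassifier(torch.nn.Module):
    """Hypothetical model; only forward() and validate() are shown."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.fc = torch.nn.Linear(num_classes, num_classes)
        with torch.no_grad():  # identity weights so logits equal the inputs
            self.fc.weight.copy_(torch.eye(num_classes))
            self.fc.bias.zero_()

    def forward(self, batch):
        inputs, _ = batch
        return self.fc(inputs)

    def validate(self, batch):
        _, targets = batch
        logits = self.forward(batch)
        return logits.argmax(dim=1), targets  # (predictions, targets)


class CorrectCount:
    """Hypothetical stand-in for a metric whose update() takes (preds, targets)."""

    def __init__(self):
        self.correct = 0

    def update(self, preds, targets):
        self.correct += int((preds == targets).sum())


model = TinyClassifier().eval()
targets = torch.tensor([0, 2, 1, 2])
inputs = torch.nn.functional.one_hot(targets, num_classes=3).float()

metric = CorrectCount()
with torch.no_grad():
    # The returned tuple is unpacked directly into update(preds, targets).
    metric.update(*model.validate((inputs, targets)))
print(metric.correct)  # 4
```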