Source code for composer.models.base

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""The ComposerModel base interface."""
from __future__ import annotations

import abc
import copy
from typing import Any, Dict, Optional, Sequence, Union

import torch
from torch import Tensor
from torchmetrics import Metric

from composer.core import Batch
from composer.loggers import Logger

__all__ = ['ComposerModel']


class ComposerModel(torch.nn.Module, abc.ABC):
    """The interface needed to make a PyTorch model compatible with :class:`composer.Trainer`.

    To create a :class:`.Trainer`\\-compatible model, subclass :class:`.ComposerModel` and
    implement :meth:`forward` and :meth:`loss`. For full functionality (logging and validation),
    implement :meth:`get_metrics` and :meth:`eval_forward`.

    See the :doc:`Composer Model walk through </composer_model>` for more details.

    Minimal Example:

    .. code-block:: python

        import torchvision
        import torch.nn.functional as F
        from composer.models import ComposerModel

        class ResNet18(ComposerModel):

            def __init__(self):
                super().__init__()
                self.model = torchvision.models.resnet18()  # define PyTorch model in __init__.

            def forward(self, batch):
                # batch is the output of the dataloader
                # specify how batches are passed through the model
                inputs, _ = batch
                return self.model(inputs)

            def loss(self, outputs, batch):
                # pass batches and `forward` outputs to the loss
                _, targets = batch
                return F.cross_entropy(outputs, targets)

    Attributes:
        logger (Optional[Logger]): The training :class:`.Logger`.
            The trainer sets the :class:`.Logger` on the :attr:`.Event.INIT` event.
    """

    def __init__(self) -> None:
        super().__init__()
        self.logger: Optional[Logger] = None

    def __deepcopy__(self, memo: dict):
        # From https://stackoverflow.com/questions/1500718/how-to-override-the-copy-deepcopy-operations-for-a-python-object
        # The `logger` should not be copied
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k == 'logger':
                copied_v = v
            else:
                copied_v = copy.deepcopy(v, memo)
            setattr(result, k, copied_v)
        return result

    def __copy__(self):
        # From https://stackoverflow.com/questions/1500718/how-to-override-the-copy-deepcopy-operations-for-a-python-object
        # Need to manually define `__copy__` so it does not rely on `__getstate__`, which would not copy the logger.
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result

    def __getstate__(self):
        # Don't pickle the logger
        state = self.__dict__.copy()
        state['logger'] = None
        return state
    @abc.abstractmethod
    def forward(self, batch: Batch) -> Union[Tensor, Sequence[Tensor]]:
        """Compute model output given a batch from the dataloader.

        Args:
            batch (~composer.core.types.Batch): The output batch from dataloader.

        Returns:
            Tensor | Sequence[Tensor]:
                The result that is passed to :meth:`loss` as the parameter :attr:`outputs`.

        .. warning:: This method is different from vanilla PyTorch ``model.forward(x)`` or ``model(x)``
                     as it takes a batch of data that has to be unpacked.

        Example:

        .. code-block:: python

            def forward(self, batch):
                # batch is the output of the dataloader
                inputs, _ = batch
                return self.model(inputs)

        The outputs of :meth:`forward` are passed to :meth:`loss` by the trainer:

        .. code-block:: python

            for batch in train_dataloader:
                optimizer.zero_grad()
                outputs = model.forward(batch)
                loss = model.loss(outputs, batch)
                loss.backward()
        """
        pass
    @abc.abstractmethod
    def loss(self, outputs: Any, batch: Batch, *args, **kwargs) -> Union[Tensor, Sequence[Tensor]]:
        """Compute the loss of the model given ``outputs`` from :meth:`forward` and a
        :class:`~composer.core.types.Batch` of data from the dataloader.

        The :class:`.Trainer` will call ``.backward()`` on the returned loss.

        Args:
            outputs (Any): The output of the forward pass.
            batch (~composer.core.types.Batch): The output batch from dataloader.

        Returns:
            Tensor | Sequence[Tensor]: The loss as a :class:`torch.Tensor`.

        Example:

        .. code-block:: python

            import torch.nn.functional as F

            def loss(self, outputs, batch):
                # pass batches and :meth:`forward` outputs to the loss
                _, targets = batch  # discard inputs from batch
                return F.cross_entropy(outputs, targets)

        The outputs of :meth:`forward` are passed to :meth:`loss` by the trainer:

        .. code-block:: python

            for batch in train_dataloader:
                optimizer.zero_grad()
                outputs = model.forward(batch)
                loss = model.loss(outputs, batch)
                loss.backward()
        """
        pass
    def eval_forward(
        self,
        batch: Any,
        outputs: Optional[Any] = None,
    ) -> Any:
        """Run the evaluation forward pass.

        By default, it returns the ``outputs`` if they are not None. Otherwise, ``self(batch)`` is returned.

        Override this method for models that require custom validation logic -- e.g. self-supervised learning.

        Args:
            batch: The dataloader batch.
            outputs (Any, optional): If training, the outputs from the forward pass. Otherwise, None.

        Returns:
            Any: The evaluation outputs.
        """
        return outputs if outputs is not None else self.forward(batch)
    def update_metric(
        self,
        batch: Any,
        outputs: Any,
        metric: Metric,
    ) -> Optional[Dict]:
        """Update the given metric.

        Args:
            batch: The dataloader batch.
            outputs: The output from :meth:`eval_forward`.
            metric (Metric): The metric to update.

        Returns:
            Optional[Dict]: Optionally return metric results to be stored in state.
        """
        raise NotImplementedError()
    def get_metrics(self, is_train: bool) -> Dict[str, Metric]:
        """Get the metrics.

        This method will be called by the trainer immediately after :attr:`.Event.INIT`.

        .. note::

            Each item in the returned dictionary will be copied with ``copy.deepcopy`` before it is used,
            so that each dataloader (e.g. train, eval) accumulates metrics separately.

            To share a metric across all dataloaders, wrap it with ``MetricSpec(metric=metric, share=False)``.

        Args:
            is_train (bool): Whether the training metrics or evaluation metrics should be returned.

        Returns:
            Dict[str, Metric]: A mapping of the metric name to a Metric.
        """
        return {}
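

# --------------------------------------------------------------------------- #
# Illustrative sketch (editorial addition, not part of the library source).
# The interface above leaves ``get_metrics`` and ``update_metric`` to the
# subclass. Below is a minimal sketch of a classifier that wires them up,
# assuming a simple 10-class task. ``SimpleClassifier`` is a hypothetical name,
# and the metric class comes from ``torchmetrics``.

import torch.nn.functional as F
from torchmetrics.classification import MulticlassAccuracy


class SimpleClassifier(ComposerModel):
    """Hypothetical example subclass; not part of ``composer.models``."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.LazyLinear(num_classes))
        # Separate metric instances per split; the trainer deep-copies each metric per dataloader.
        self.train_accuracy = MulticlassAccuracy(num_classes=num_classes)
        self.eval_accuracy = MulticlassAccuracy(num_classes=num_classes)

    def forward(self, batch):
        inputs, _ = batch
        return self.model(inputs)

    def loss(self, outputs, batch):
        _, targets = batch
        return F.cross_entropy(outputs, targets)

    def get_metrics(self, is_train: bool) -> Dict[str, Metric]:
        # Return the metrics for the requested split, keyed by name.
        return {'MulticlassAccuracy': self.train_accuracy if is_train else self.eval_accuracy}

    def update_metric(self, batch, outputs, metric: Metric) -> None:
        # Called with the outputs of ``eval_forward``; updates the metric state in place.
        _, targets = batch
        metric.update(outputs, targets)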
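

# ``eval_forward`` can be overridden when evaluation needs its own logic. The
# hypothetical subclass below (building on the sketch above) averages
# predictions over the input and its horizontally flipped copy during
# evaluation only; training behavior is unchanged.


class FlipTTAClassifier(SimpleClassifier):
    """Hypothetical test-time-augmentation example; not part of ``composer.models``."""

    def eval_forward(self, batch, outputs=None):
        # Ignore any training-time outputs and recompute predictions with augmentation.
        inputs, _ = batch
        logits = self.model(inputs)
        flipped_logits = self.model(torch.flip(inputs, dims=[-1]))
        return (logits + flipped_logits) / 2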