Source code for composer.core.state

# Copyright 2021 MosaicML. All Rights Reserved.

"""The state of the trainer."""
from __future__ import annotations

import contextlib
import logging
import textwrap
import warnings
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Optional, Sequence, Union, cast

import torch
import torch.nn.modules.utils
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer

from composer.core.precision import Precision
from composer.core.serializable import Serializable
from composer.core.time import Time, Timer, TimeUnit
from composer.utils import dist, ensure_tuple

if TYPE_CHECKING:
    import deepspeed

    import composer.core.types as types
    from composer.core.algorithm import Algorithm
    from composer.core.callback import Callback
    from composer.core.evaluator import Evaluator
    from composer.profiler import Profiler

__all__ = ["State"]

logger = logging.getLogger(__name__)


def _default_precision_factory() -> Callable[[Union[str, Precision]], ContextManager]:
    """Returns a context manager to automatically cast to a specific precision.

    Args:
        precision (str or Precision): Precision for the context
    """
    if torch.cuda.is_available():
        return lambda precision: torch.cuda.amp.autocast(enabled=Precision(precision) == Precision.AMP)
    else:

        def null(precision):
            assert Precision(
                precision) != Precision.AMP, "Precision AMP is only available when `torch.cuda.is_available() == True`."
            return contextlib.nullcontext()

        return null
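
# Illustrative sketch (not part of the original module): the factory returns a callable that maps a
# precision to a context manager. On a CUDA machine, requesting ``Precision.AMP`` enables
# ``torch.cuda.amp.autocast``; any other precision (or a CPU-only machine) yields a null context.
#
#     precision_context = _default_precision_factory()
#     with precision_context(Precision.FP32):   # no-op context
#         ...
#     with precision_context("amp"):            # autocast enabled on CUDA machines
#         ...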


def _ensure_backwards_compatible_checkpointing(state_dict: Dict[str, Any]):
    # v0.4.1 removed the leading underscores for the keys in the state_dict
    # It also renamed _is_model_ddp_wrapped to is_model_ddp
    state = {}
    for k, v in state_dict.items():
        if k == "_is_model_ddp_wrapped":
            k = "is_model_ddp"
        if k.startswith("_"):
            k = k[1:]
        state[k] = v
    return state
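
# Illustrative sketch (not part of the original module): pre-v0.4.1 checkpoints used leading underscores,
# so this helper remaps keys before loading, e.g.:
#
#     _ensure_backwards_compatible_checkpointing({"_timer": t, "_is_model_ddp_wrapped": False})
#     # -> {"timer": t, "is_model_ddp": False}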


_STATE_DICT_SERIALIZED_ATTRIBUTES = [
    # List of attributes that are serialized with state_dict
    # Only the attributes listed in state.serialized_attributes will actually be saved.
    "model",
    "optimizers",
    "schedulers",
    "algorithms",
    "callbacks",
    "scaler",
    "timer",
]


class State(Serializable):
    """The state of the trainer.

    Contains variables that the trainer tracks throughout the training loop. Note that all the necessary parts
    (i.e., :attr:`serialized_attributes`) of state are serialized when the trainer is checkpointed so that it can
    be used to restore the trainer and continue training from a checkpoint. :mod:`~composer.algorithms` are able
    to modify an instance of this class in-place.

    .. note::

        An instance of this class is automatically constructed by the :class:`~.Trainer` constructor. A user need
        not instantiate this class.

    Args:
        model (torch.nn.Module): The model, typically as a subclass of :class:`~.ComposerModel`.
        rank_zero_seed (int): The seed used on the rank zero process. It is assumed that each rank's seed is
            ``rank_zero_seed + dist.get_global_rank()``.
        grad_accum (int): The number of gradient accumulation steps to use. With this argument, the microbatch size
            for each device becomes ``microbatch_size = train_batch_size / (num_devices * grad_accum)``.
        train_dataloader (types.DataLoader, DataSpec, or dict): The :class:`~.types.DataLoader`, :class:`~.DataSpec`,
            or dict of :class:`~.DataSpec` kwargs to use for training.
        evaluators (Evaluator | Sequence[Evaluator]): The evaluators contain the evaluation dataset(s) used for
            evaluation with specific metrics.
        max_duration (str or Time): The maximum duration to train for.
        precision (str | Precision): The numerical precision to use for training. See :class:`~.Precision` for the
            supported precisions.
        precision_context (Callable[[Precision], ContextManager]): Function to produce a context manager to mandate
            precision.
        optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional): The optimizer being used to
            train the model. Multiple optimizers are not currently supported.
        schedulers (types.PyTorchScheduler | Sequence[types.PyTorchScheduler], optional): The learning rate scheduler
            (can also be a list or tuple of schedulers).
        scaler (torch.cuda.amp.GradScaler, optional): The gradient scaler in use for mixed precision training.
        algorithms (Algorithm | Sequence[Algorithm], optional): The algorithms used for training.
        callbacks (Callback | Sequence[Callback], optional): The callbacks used for training.
        profiler (Profiler, optional): The Composer profiler.

    Attributes:
        batch (types.Batch): The batch. This will be the entire batch during the :attr:`.Event.AFTER_DATALOADER`, or
            a microbatch between :attr:`.Event.BATCH_START` and :attr:`.Event.BATCH_END`.
        batch_num_samples (int): The number of samples in the :attr:`batch`.
        batch_num_tokens (int): The number of tokens in the :attr:`batch`.
        current_metrics (Dict[str, Dict[str, Any]]): The current computed metrics, organized by dataloader label and
            then by metric name. The train dataloader is labeled ``'train'``. If not using an :class:`.Evaluator`,
            the eval dataloader is labeled ``'eval'``. Otherwise, the evaluator label is used.

            For example:

            >>> trainer = Trainer(
            ...     ...,
            ...     compute_training_metrics=True,
            ...     train_dataloader=train_dataloader,
            ...     eval_dataloader=eval_dataloader,
            ... )
            >>> trainer.fit()
            >>> trainer.state.current_metrics
            {'train': {'Accuracy': tensor(...)}, 'eval': {'Accuracy': tensor(...)}}

            Or, when using an :class:`.Evaluator`:

            .. testsetup::

                eval_1_dl = eval_dataloader
                eval_2_dl = eval_dataloader

            >>> from torchmetrics import Accuracy
            >>> from composer.core import Evaluator
            >>> trainer = Trainer(
            ...     ...,
            ...     compute_training_metrics=True,
            ...     train_dataloader=train_dataloader,
            ...     eval_dataloader=[
            ...         Evaluator(label='eval1', dataloader=eval_1_dl, metrics=Accuracy()),
            ...         Evaluator(label='eval2', dataloader=eval_2_dl, metrics=Accuracy()),
            ...     ],
            ... )
            >>> trainer.fit()
            >>> trainer.state.current_metrics
            {'train': {'Accuracy': tensor(...)}, 'eval1': {'Accuracy': tensor(...)}, 'eval2': {'Accuracy': tensor(...)}}

        loss (torch.Tensor | Sequence[torch.Tensor]): The most recently computed loss.
        outputs (torch.Tensor | Sequence[torch.Tensor]): The most recently computed output from the model's forward
            pass.
        timer (Timer): The timer that tracks training loop progress.
        serialized_attributes (List[str]): The names of the attributes which are serialized in a checkpoint.

            By default, the following attributes are serialized:

            +-----------------------+-------------------------------------------------------------+
            | Attribute             | Description                                                 |
            +=======================+=============================================================+
            | model                 | The model under training.                                   |
            +-----------------------+-------------------------------------------------------------+
            | optimizers            | The optimizers being used to train the model.               |
            +-----------------------+-------------------------------------------------------------+
            | schedulers            | The learning rate schedulers.                               |
            +-----------------------+-------------------------------------------------------------+
            | algorithms            | The algorithms used for training.                           |
            +-----------------------+-------------------------------------------------------------+
            | callbacks             | The callbacks used for training.                            |
            +-----------------------+-------------------------------------------------------------+
            | scaler                | The gradient scaler in use for mixed precision training.    |
            +-----------------------+-------------------------------------------------------------+
            | timer                 | The timer that tracks training loop progress.               |
            +-----------------------+-------------------------------------------------------------+
            | rank_zero_seed        | The seed of the rank zero process.                          |
            +-----------------------+-------------------------------------------------------------+
            | current_metrics       | The current metrics.                                        |
            +-----------------------+-------------------------------------------------------------+
    """

    _max_duration: Time[int]
    _steps_per_epoch: Optional[int]
    batch: types.Batch
    batch_num_samples: int
    batch_num_tokens: int
    loss: Union[torch.Tensor, Sequence[torch.Tensor]]
    outputs: Union[torch.Tensor, Sequence[torch.Tensor]]
    _schedulers: List[types.PyTorchScheduler]

    def __init__(
        self,
        # model
        model: torch.nn.Module,

        # stopping conditions
        max_duration: Union[str, Time[int]],
        rank_zero_seed: int,

        # data configurations
        train_dataloader: types.DataLoader,
        evaluators: Optional[Union[Evaluator, Sequence[Evaluator]]] = None,
        grad_accum: int = 1,

        # precision
        precision: Union[str, Precision] = Precision.FP32,
        precision_context: Callable[[Precision], ContextManager] = _default_precision_factory(),

        # optimizers
        optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,

        # scaler
        scaler: Optional[torch.cuda.amp.grad_scaler.GradScaler] = None,

        # algorithms and callbacks
        algorithms: Optional[Union[Algorithm, Sequence[Algorithm]]] = None,
        callbacks: Optional[Union[Callback, Sequence[Callback]]] = None,

        # steps per epoch
        steps_per_epoch: Optional[int] = None,
    ):
        self.rank_zero_seed = rank_zero_seed
        self.model = model
        self.grad_accum = grad_accum
        self.train_dataloader = train_dataloader
        self.evaluators = list(ensure_tuple(evaluators))
        self.max_duration = max_duration
        self.steps_per_epoch = steps_per_epoch

        self.timer = Timer()
        self._precision = Precision(precision)
        self._precision_context = precision_context

        if optimizers is None:
            self._optimizers = []
        else:
            self._optimizers = list(ensure_tuple(optimizers))

        self._schedulers = []

        self.scaler = scaler
        self._algorithms = list(ensure_tuple(algorithms))
        self._callbacks = list(ensure_tuple(callbacks))

        self.profiler: Optional[Profiler] = None

        # These attributes will be serialized using .state_dict() and loaded with .load_state_dict().
        # All other attributes will not be serialized.
        # For simplicity, omit the leading underscore for private attributes.
        # For example, even though the optimizers are stored on the state
        # as the "_optimizers" attribute, here we specify just "optimizers".
        self.serialized_attributes = [
            "model",
            "optimizers",
            "schedulers",
            "algorithms",
            "callbacks",
            "scaler",
            "timer",
            "rank_zero_seed",
            "current_metrics",
        ]

        self.current_metrics: Dict[str, Dict[str, Any]] = {}

    @property
    def seed(self):
        """The seed for the current rank."""
        return self.rank_zero_seed + dist.get_global_rank()

    @property
    def max_duration(self):
        """The maximum training duration."""
        return self._max_duration

    @max_duration.setter
    def max_duration(self, max_duration: Union[str, Time[int]]):
        if isinstance(max_duration, str):
            max_duration = cast(Time[int], Time.from_timestring(max_duration))
        if max_duration.unit == TimeUnit.DURATION:
            raise ValueError("TimeUnit.DURATION is not allowed as a unit for max_duration")
        self._max_duration = max_duration
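
    # Illustrative sketch (not part of the original module): the ``max_duration`` setter also accepts a
    # timestring, e.g. ``state.max_duration = "90ep"`` is parsed via ``Time.from_timestring`` into
    # ``Time(90, TimeUnit.EPOCH)``, while duration-unit values such as ``"0.5dur"`` are rejected because
    # the total duration cannot be defined in terms of itself.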

    def get_elapsed_duration(self) -> Time[float]:
        """Get the elapsed training duration.

        Returns:
            Time: The elapsed duration, in :attr:`TimeUnit.DURATION`. ``Time(0.0, TimeUnit.DURATION)`` represents
                the beginning of training, and ``Time(1.0, TimeUnit.DURATION)`` represents a completed training
                process.
        """
        return self.timer.get(self.max_duration.unit) / self.max_duration
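
    # Illustrative sketch (not part of the original module): with ``max_duration = Time(100, TimeUnit.EPOCH)``
    # and the timer currently at 25 epochs, dividing ``Time(25, TimeUnit.EPOCH)`` by the maximum duration
    # yields ``Time(0.25, TimeUnit.DURATION)``, i.e. 25% of training has elapsed.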

    @property
    def optimizers(self):
        return self._optimizers

    @optimizers.setter
    def optimizers(self, optimizers: Union[Optimizer, Sequence[Optimizer]]):
        self._optimizers[:] = ensure_tuple(optimizers)

    @property
    def schedulers(self):
        return self._schedulers

    @schedulers.setter
    def schedulers(self, schedulers: types.PyTorchScheduler):
        self._schedulers[:] = ensure_tuple(schedulers)

    @property
    def callbacks(self):
        return self._callbacks

    @callbacks.setter
    def callbacks(self, callbacks: Sequence[Callback]):
        self._callbacks[:] = callbacks

    @property
    def algorithms(self):
        return self._algorithms

    @algorithms.setter
    def algorithms(self, algorithms: Sequence[Algorithm]):
        self._algorithms[:] = algorithms

    def state_dict(self) -> Dict[str, Any]:
        """Returns the state as a :class:`dict`."""
        state_dict = {}

        for attribute_name in self.serialized_attributes:
            attribute_value = getattr(self, attribute_name)
            if attribute_name == "model":
                # Save the model directly instead of by class name, since the model may be wrapped by
                # DistributedDataParallel. If it is DDP-wrapped, do not save the `module.` prefix, as that is an
                # implementation detail.
                model_state = attribute_value.state_dict()
                if self.is_model_ddp:
                    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state, "module.")
                serialized_value = model_state
            else:
                if attribute_name in _STATE_DICT_SERIALIZED_ATTRIBUTES:
                    serialized_value = {
                        type(obj).__qualname__: obj.state_dict() for obj in ensure_tuple(attribute_value)
                    }
                else:
                    serialized_value = attribute_value

            state_dict[attribute_name] = serialized_value

        return state_dict
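
    # Illustrative sketch (not part of the original module): the resulting dictionary is nested for the
    # attributes listed in ``_STATE_DICT_SERIALIZED_ATTRIBUTES`` (keyed by each object's qualified class
    # name) and stored directly otherwise, e.g.:
    #
    #     {
    #         "model": {...},                        # raw model state_dict, without any `module.` prefix
    #         "optimizers": {"SGD": {...}},          # keyed by type(obj).__qualname__
    #         "timer": {"Timer": {...}},
    #         "rank_zero_seed": 42,                  # not in _STATE_DICT_SERIALIZED_ATTRIBUTES, stored as-is
    #         "current_metrics": {"train": {"Accuracy": ...}},
    #     }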

    def load_model_state(self, state_dict: Dict[str, Any], strict: bool):
        """Loads the model's state from a ``state_dict``.

        Args:
            state_dict (Dict[str, Any]): The state dict, generated from a previous call to :meth:`state_dict`.
            strict (bool): Whether the keys (i.e., model parameter names) in the model state dict should
                perfectly match the keys in the model instance.
        """
        if state_dict.get("is_model_ddp", False) and not self.is_model_ddp:
            # This check is for backwards compatibility, as pre-v0.6.0 checkpoints serialized the state
            # with the `module.` prefix
            torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], "module.")
        missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
        if len(missing_keys) > 0:
            logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
        if len(unexpected_keys) > 0:
            logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")

    def load_state_dict(self, state: Dict[str, Any], strict: bool = False):
        """Loads the state.

        Args:
            state (Dict[str, Any]): The object returned from a call to :meth:`state_dict`.
            strict (bool): Whether the keys in ``state["model"]`` should perfectly match the keys in
                ``self.model``. Defaults to False.
        """

        state = _ensure_backwards_compatible_checkpointing(state)

        for attribute_name, serialized_value in state.items():
            if attribute_name not in self.serialized_attributes:
                # it's possible some attributes were removed; skip them
                continue

            if attribute_name == "model":
                self.load_model_state(state, strict=strict)
                continue
            state_field_value = getattr(self, attribute_name)
            if attribute_name in _STATE_DICT_SERIALIZED_ATTRIBUTES:
                for target in ensure_tuple(state_field_value):
                    if type(target).__qualname__ not in serialized_value:
                        warnings.warn(
                            f"{type(target).__qualname__} is not in the state_dict. Its state will not be restored.",
                            category=UserWarning)
                        continue
                    source = serialized_value[type(target).__qualname__]
                    target.load_state_dict(source)
            else:
                # direct serialization
                try:
                    setattr(self, attribute_name, serialized_value)
                except AttributeError:
                    # ignore AttributeError for properties that have getters but not setters
                    pass
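
    # Illustrative sketch (not part of the original module): checkpointing round-trips through these two
    # methods, e.g.:
    #
    #     saved = state.state_dict()
    #     # ... later, on a freshly constructed State with the same model/optimizers ...
    #     new_state.load_state_dict(saved)   # restores the model, optimizers, timer, metrics, etc.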

    @property
    def steps_per_epoch(self):
        """int: The maximum number of steps (batches) per epoch."""
        if self._steps_per_epoch is None:
            return len(self.train_dataloader)
        return self._steps_per_epoch

    @steps_per_epoch.setter
    def steps_per_epoch(self, steps_per_epoch: Optional[int]):
        try:
            dataloader_len = len(self.train_dataloader)
        except (TypeError, NotImplementedError):
            dataloader_len = None
        if dataloader_len is not None and steps_per_epoch is not None and steps_per_epoch > dataloader_len:
            warnings.warn(
                textwrap.dedent(f"""\
                    SubsetNumBatchesWarning: The steps_per_epoch({steps_per_epoch}) is greater than
                    the number of batches in the training dataloader ({dataloader_len})"""))
        self._steps_per_epoch = steps_per_epoch

    @property
    def precision(self):
        """The numerical precision to use for training.

        See :class:`~.Precision` for the supported precisions.
        """
        return self._precision

    @precision.setter
    def precision(self, precision: Union[str, Precision]):
        self._precision = Precision(precision)

    @property
    def batch_pair(self) -> types.BatchPair:
        """:attr:`~.types.BatchPair`: The current batch, represented as a :attr:`~.types.BatchPair`.

        Raises:
            TypeError: If the current batch is not a :attr:`~.types.BatchPair`.
        """
        from composer.core.types import as_batch_pair
        return as_batch_pair(self.batch)

    @property
    def batch_dict(self) -> types.BatchDict:
        """:attr:`~.types.BatchDict`: The current batch, represented as a :attr:`~.types.BatchDict`.

        Raises:
            TypeError: If the current batch is not a :attr:`~.types.BatchDict`.
        """
        from composer.core.types import as_batch_dict
        return as_batch_dict(self.batch)

    @property
    def precision_context(self):
        return self._precision_context(self.precision)

    @property
    def is_model_deepspeed(self) -> bool:
        """Whether :attr:`model` is an instance of a :class:`~deepspeed.DeepSpeedEngine`."""
        try:
            import deepspeed
        except ImportError:
            return False
        else:
            return isinstance(self.model, deepspeed.DeepSpeedEngine)

    @property
    def is_model_ddp(self):
        """Whether :attr:`model` is an instance of a :class:`.DistributedDataParallel`."""
        return isinstance(self.model, DistributedDataParallel)

    @property
    def deepspeed_model(self) -> deepspeed.DeepSpeedEngine:
        """Cast :attr:`model` to :class:`~deepspeed.DeepSpeedEngine`."""
        if self.is_model_deepspeed:
            return cast("deepspeed.DeepSpeedEngine", self.model)
        raise TypeError("state.model is not a DeepSpeed model")
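

# Illustrative sketch (not part of the original module): inside the training loop, the trainer can wrap the
# forward/loss computation in ``state.precision_context`` so that AMP autocasting is applied only when
# ``state.precision == Precision.AMP``:
#
#     with state.precision_context:
#         outputs = state.model(state.batch)
#         loss = compute_loss(outputs, state.batch)   # ``compute_loss`` is a hypothetical helper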