# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""The state of the trainer."""
from __future__ import annotations
import collections.abc
import logging
import textwrap
import warnings
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union, cast
import torch
import torch.nn.modules.utils
from packaging import version
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Metric
from torchmetrics.metric import jit_distributed_available
from composer.core.data_spec import DataSpec
from composer.core.event import Event
from composer.core.precision import Precision
from composer.core.serializable import Serializable
from composer.core.time import Time, Timestamp, TimeUnit
from composer.devices import Device
from composer.utils import batch_get, batch_set, dist, ensure_tuple, get_composer_env_dict, is_model_deepspeed
if TYPE_CHECKING:
import deepspeed
import composer.core.types as types
from composer.core.algorithm import Algorithm
from composer.core.callback import Callback
from composer.core.evaluator import Evaluator
from composer.core.passes import AlgorithmPass
from composer.loggers import Logger
from composer.profiler import Profiler
__all__ = ['State']
log = logging.getLogger(__name__)
@contextmanager
def fsdp_state_dict_type_context(module: torch.nn.Module, state_dict_type: str = 'full'):
"""Context manager for materializing or loading an fsdp module's state dict.
Args:
module (torch.nn.Module): The torch module that you want to call `state_dict()`
or `load_state_dict()` on.
state_dict_type (str, optional): Which of the three state dict types you want to use.
Choices are ['full', 'sharded', 'local']. Defaults to 'full'.
* 'full': the full, unsharded state dict materialized only on rank 0 with cpu_offload if necessary
* 'local': the sharded, flattened state_dict, where each rank only gets a single shard.
* 'sharded': the sharded, unflattened state_dict, where each rank only gets a single shard.
See torch.distributed.fsdp.StateDictType for more info.
Raises:
RuntimeError: if your torch version is earlier than 1.13.0 because FSDP is not available for those versions.
NotImplementedError: if you specify a state_dict_type not in ['full', 'sharded', 'local'].
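Example (a minimal sketch; ``fsdp_model`` is assumed to be a module already wrapped with FSDP):

.. code-block:: python

    with fsdp_state_dict_type_context(fsdp_model, state_dict_type='full'):
        # with 'full', the monolithic state dict is materialized on rank 0 only
        state_dict = fsdp_model.state_dict()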
"""
if version.parse(torch.__version__) < version.parse('1.13.0'):
raise RuntimeError('To use FSDP with Composer, you must use torch>=1.13.0.')
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import LocalStateDictConfig, StateDictType
# torch forgot to put ShardedStateDictConfig in torch/distributed/fsdp/__init__.py, so we
# have to import it this way.
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardedStateDictConfig
# Full is the full monolithic state dict materialized in memory on just rank 0
# with offloading to cpu if necessary
if state_dict_type == 'full':
state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
fsdp_state_dict_type = StateDictType.FULL_STATE_DICT
# Sharded is sharded state dict, but unflattened parameters (not useful for FSDP, but
# useful if you plan to use the state dict outside of FSDP).
elif state_dict_type == 'sharded':
state_dict_config = ShardedStateDictConfig()
fsdp_state_dict_type = StateDictType.SHARDED_STATE_DICT
# Local is the FSDP standard sharded, flattened parameters. This is what the parameters
# are formatted to for a single rank's FSDP module.
elif state_dict_type == 'local':
state_dict_config = LocalStateDictConfig()
fsdp_state_dict_type = StateDictType.LOCAL_STATE_DICT
else:
raise NotImplementedError(f'No valid FSDP state_dict_type for {state_dict_type}')
with FSDP.state_dict_type(module, state_dict_type=fsdp_state_dict_type, state_dict_config=state_dict_config):
yield
def fsdp_get_optim_state_dict(model: torch.nn.Module,
optim: torch.optim.Optimizer,
state_dict_type: str = 'full') -> Dict[str, Any]:
"""Materializes a given model's optimizer's state_dict.
Args:
model (torch.nn.Module): The model that the optimizer corresponds to.
optim (torch.optim.Optimizer): The optimizer that you want a state dict for.
state_dict_type (str, optional): Which of the three state dict types you want to use.
Choices are ['full', 'sharded', 'local']. Defaults to 'full'.
* 'full': the full, unsharded state dict materialized only on rank 0
* 'local': the sharded, flattened state_dict, where each rank only gets a single shard.
* 'sharded': the sharded, unflattened state_dict, where each rank only gets a single shard.
Raises:
RuntimeError: if your torch version is earlier than 1.13.0 because FSDP is not available for those versions.
NotImplementedError: if you specify a state_dict_type not in ['full', 'sharded', 'local'].
Returns:
Dict[str, Any]: The state_dict for the given optimizer.
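Example (a minimal sketch; ``fsdp_model`` and ``optimizer`` are assumed to be an FSDP-wrapped module and its optimizer):

.. code-block:: python

    optim_state = fsdp_get_optim_state_dict(fsdp_model, optimizer, state_dict_type='full')
    # with 'full', the returned dict is materialized on rank 0 only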
"""
if version.parse(torch.__version__) < version.parse('1.13.0'):
raise RuntimeError('To use FSDP with Composer, you must use torch>=1.13.0.')
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
if state_dict_type == 'full':
# Converts local state dict to full.
return FSDP.full_optim_state_dict(model=model, optim=optim)
elif state_dict_type == 'sharded':
# Converts local state dict to sharded.
return FSDP.sharded_optim_state_dict(model=model, optim=optim)
elif state_dict_type == 'local':
# State dict is already local, so just return state dict.
return optim.state_dict()
else:
raise NotImplementedError(f'No valid FSDP state_dict_type for {state_dict_type}')
def get_fsdp_sharded_optim_state_dict(full_optim_state_dict: Dict[str, Any], model: torch.nn.Module):
    """Shard a full optimizer state dict across ranks for the given FSDP-wrapped model."""
if version.parse(torch.__version__) < version.parse('1.13.0'):
raise RuntimeError('To use FSDP with Composer, you must use torch>=1.13.0.')
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
log.debug(
f'Scattering optimizer state dict with keys {full_optim_state_dict.keys()} and model of type {type(model)}')
return FSDP.scatter_full_optim_state_dict(full_optim_state_dict=full_optim_state_dict, model=model)
def get_fsdp_full_optim_state_dict(model: torch.nn.Module, optim: torch.optim.Optimizer, rank0_only: bool = True):
    """Materialize the full, unsharded optimizer state dict for the given FSDP-wrapped model."""
if version.parse(torch.__version__) < version.parse('1.13.0'):
raise RuntimeError('To use FSDP with Composer, you must use torch>=1.13.0.')
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
return FSDP.full_optim_state_dict(model=model, optim=optim, rank0_only=rank0_only)
def _ensure_backwards_compatible_checkpointing(state_dict: Dict[str, Any]):
# v0.4.1 removed the leading underscores for the keys in the state_dict
# It also renamed _is_model_ddp_wrapped to is_model_ddp
state = {}
for attribute_name, serialized_value in state_dict.items():
if attribute_name == '_is_model_ddp_wrapped':
attribute_name = 'is_model_ddp'
if attribute_name.startswith('_'):
attribute_name = attribute_name[1:]
# Torchmetrics adds a new attribute as of 0.11 which must be added to deserialized metrics
if attribute_name == 'train_metrics':
for metric_name in serialized_value.keys():
metric = serialized_value[metric_name]
if not hasattr(metric, 'distributed_available_fn'):
metric.distributed_available_fn = jit_distributed_available
serialized_value[metric_name] = metric
elif attribute_name == 'eval_metrics':
for evaluator_name, eval_metrics in serialized_value.items():
for metric_name in eval_metrics.keys():
metric = eval_metrics[metric_name]
if not hasattr(metric, 'distributed_available_fn'):
metric.distributed_available_fn = jit_distributed_available
serialized_value[evaluator_name][metric_name] = metric
state[attribute_name] = serialized_value
return state
_STATE_DICT_SERIALIZED_ATTRIBUTES = [
# List of attributes that are serialized with state_dict
# Only the attributes listed in state.serialized_attributes will actually be saved.
'model',
'optimizers',
'schedulers',
'algorithms',
'callbacks',
'scaler',
'timestamp',
]
class State(Serializable):
"""The state of the trainer.
Contains variables that the trainer tracks throughout the training loop. Note that all the necessary parts (i.e.,
:attr:`serialized_attributes`) of state are serialized when the trainer is checkpointed so that it can be used to
restore the trainer and continue training from a checkpoint. :mod:`~composer.algorithms` are able to modify an
instance of this class in-place.
.. note::
An instance of this class is automatically constructed by the :class:`~.Trainer` constructor. A user need
not instantiate this class.
Args:
model (torch.nn.Module): The model, typically as a subclass of :class:`~.ComposerModel`.
rank_zero_seed (int): The seed used on the rank zero process. It is assumed that each rank's seed is
``rank_zero_seed + dist.get_global_rank()``.
run_name (str): The name for this training run.
device (Device): The device used by this process. The trainer moves the model and loaded data to this device.
grad_accum (int, optional): The number of gradient accumulation steps to use. With this argument, micro batch
size for each device becomes ``microbatch_size = train_batch_size / (num_devices * grad_accum)``.
eval_batch_split (int, optional): The mirror of grad_accum for eval. With this argument, micro batch
size for each device becomes ``microbatch_size = eval_batch_size / (num_devices * eval_batch_split)``.
device_train_microbatch_size (int, optional): The microbatch size for each device during training.
auto_microbatching (bool, optional): Whether automatic microbatching is enabled.
using_device_microbatch_size (bool, optional): Whether device_train_microbatch_size is set by the user.
train_dataloader (Iterable, optional): Dataloader used for training
evaluators (Evaluator | Sequence[Evaluator], optional): The :class:`.Evaluator` objects used for evaluation.
dataloader (Iterable, optional): The active DataLoader.
dataloader_len (int | Time[int], optional): The number of batches per dataloader iteration (e.g. epoch).
The trainer will yield the first ``dataloader_len`` batches per iteration. If ``-1`` (the default),
the entire dataloader will be iterated over.
dataloader_label (str, optional): The name for the dataloader. Required if ``dataloader`` is specified.
(default: ``None``)
By convention, the training dataloader is called ``'train'``. The evaluator dataloader is called
``'eval'``, or when multiple evaluators are used, the name of the evaluator.
dataset_state (Dict[str, Any], optional): Mapping of dataset split to its iteration state for resumption.
dataset_resumption (Dict[str, Any], optional): Mapping of dataset split to whether resumption is used.
max_duration (str | Time, optional): The maximum duration to train for. (default: ``None``)
precision (str | Precision): The numerical precision to use for training. See :class:`~.Precision` for
the supported precisions.
optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional): The optimizer being used to
train the model. Multiple optimizers are not currently supported.
schedulers (types.PyTorchScheduler | Sequence[types.PyTorchScheduler], optional):
The learning rate scheduler (can also be a list or tuple of schedulers).
scaler (torch.cuda.amp.GradScaler, optional): The gradient scaler in use for mixed precision training.
algorithms (Algorithm | Sequence[Algorithm], optional): The algorithms used for training.
callbacks (Callback | Sequence[Callback], optional): The callbacks used for training.
deepspeed_config (Dict[str, Any], optional): The configuration dictionary for deepspeed.
fsdp_config (Dict[str, Any], optional): The configuration dictionary for FSDP.
Attributes:
batch (types.Batch): The batch. This will be the entire batch during the :attr:`.Event.AFTER_DATALOADER`, or a
microbatch between :attr:`.Event.BATCH_START` and :attr:`.Event.BATCH_END`.
device (Device): The device used by this process. The trainer moves the model and loaded data to this device. This
can be used in callbacks and algorithms to move data onto the correct device.
train_metrics (Dict[str, Metric]): The current train metrics, organized by metric name. ``train_metrics`` will be deep-copied to
ensure that each evaluator updates only its ``train_metrics``.
For example:
>>> trainer = Trainer(
... ...,
... train_dataloader=train_dataloader,
... eval_dataloader=eval_dataloader,
... )
>>> trainer.fit()
>>> trainer.state.train_metrics
{'MulticlassAccuracy': MulticlassAccuracy()}
eval_metrics (Dict[str, Dict[str, Metric]]): The current evaluation metrics, organized
by dataloader label and then by metric name. If not using an :class:`.Evaluator`,
the eval dataloader is labeled ``'eval'``. Otherwise, in the case of having multiple evaluation datasets,
the evaluator label is used. See the `Multiple Datasets Documentation <https://docs.mosaicml.com/en/stable/trainer/evaluation.html#multiple-datasets>`_
for more information. ``eval_metrics`` will be deep-copied to ensure that each evaluator updates only its ``eval_metrics``.
For example:
>>> from composer.metrics import CrossEntropy
>>> trainer = Trainer(
... ...,
... train_dataloader=train_dataloader,
... eval_dataloader=eval_dataloader,
... )
>>> trainer.fit()
>>> trainer.state.eval_metrics
{'eval': {'CrossEntropy': CrossEntropy(), 'MulticlassAccuracy': MulticlassAccuracy()}}
Or, when using an :class:`.Evaluator` for multiple evaluation datasets:
.. testsetup::
eval_1_dl = eval_dataloader
eval_2_dl = eval_dataloader
>>> from composer.core import Evaluator
>>> trainer = Trainer(
... ...,
... train_dataloader=train_dataloader,
... eval_dataloader=[
... Evaluator(label='eval1', dataloader=eval_1_dl, metric_names=['MulticlassAccuracy']),
... Evaluator(label='eval2', dataloader=eval_2_dl, metric_names=['MulticlassAccuracy']),
... ],
... )
>>> trainer.fit()
>>> trainer.state.eval_metrics
{'eval1': {'MulticlassAccuracy': MulticlassAccuracy()}, 'eval2': {'MulticlassAccuracy': MulticlassAccuracy()}}
eval_timestamp (Timestamp): The timestamp for the current evaluation dataloader. This timestamp is reset
before the dataloader is evaluated. The :attr:`~Timestamp.epoch` attribute for this timestamp is always
``0``.
grad_accum (int): The number of gradient accumulation steps per batch.
device_train_microbatch_size (int): The size of each train microbatch per device.
loss (torch.Tensor | Sequence[torch.Tensor]): The most recently computed loss.
model (torch.nn.Module): The training model.
.. note::
When using DeepSpeed or multi-rank training, the model will be wrapped with
:class:`~deepspeed.DeepSpeedEngine` or :class:`~torch.nn.parallel.DistributedDataParallel`,
respectively.
outputs (torch.Tensor | Sequence[torch.Tensor]): The most recently computed output from the model's forward
pass.
predict_timestamp (Timestamp): The timestamp for the current prediction dataloader. This timestamp is reset
before the dataloader is used. The :attr:`~Timestamp.epoch` attribute for this timestamp is always
``0``.
profiler (Profiler): The profiler (if profiling is enabled), or ``None`` if not profiling.
rank_zero_seed (int): The seed of the rank zero process.
run_name (str): The name for this training run.
scaler (torch.cuda.amp.GradScaler): The gradient scaler if using mixed-precision training, or
``None`` if not using mixed-precision training.
serialized_attributes (List[str]): The names of the attributes which are serialized in a checkpoint.
By default, the following attributes are serialized:
+-----------------------+-------------------------------------------------------------+
| Attribute | Description |
+=======================+=============================================================+
| model | The model under training. |
+-----------------------+-------------------------------------------------------------+
| optimizers | The optimizers being used to train the model. |
+-----------------------+-------------------------------------------------------------+
| schedulers | The learning rate schedulers. |
+-----------------------+-------------------------------------------------------------+
| algorithms | The algorithms used for training. |
+-----------------------+-------------------------------------------------------------+
| callbacks | The callbacks used for training. |
+-----------------------+-------------------------------------------------------------+
| scaler | The gradient scaler in use for mixed precision training. |
+-----------------------+-------------------------------------------------------------+
| timestamp | The timestamp that tracks training loop progress. |
+-----------------------+-------------------------------------------------------------+
| rank_zero_seed | The seed of the rank zero process. |
+-----------------------+-------------------------------------------------------------+
| train_metrics | The current training metrics |
+-----------------------+-------------------------------------------------------------+
| eval_metrics | The current evaluation metrics |
+-----------------------+-------------------------------------------------------------+
| run_name | The run name for training. |
+-----------------------+-------------------------------------------------------------+
| dataset_state | The dataset iteration state. |
+-----------------------+-------------------------------------------------------------+
timestamp (Timestamp): The current training timestamp.
"""
def __init__(
self,
# model
model: torch.nn.Module,
# determinism
rank_zero_seed: int,
# run_name
run_name: str,
# device
device: Device,
# stopping conditions
max_duration: Optional[Union[str, Time[int]]] = None,
# data configurations
grad_accum: Optional[int] = 1,
eval_batch_split: int = 1,
device_train_microbatch_size: Optional[int] = None,
auto_microbatching: bool = False,
using_device_microbatch_size: bool = True,
# dataloaders
train_dataloader: Optional[Iterable] = None,
evaluators: Optional[Union[Evaluator, Sequence[Evaluator]]] = None,
# these track the current 'active' dataloader
# depending on train, eval, or others
dataloader: Optional[Iterable] = None,
dataloader_label: Optional[str] = None,
dataloader_len: Union[int, Time[int]] = -1,
dataset_state: Optional[Dict[str, Any]] = None,
dataset_resumption: Optional[Dict[str, Any]] = None,
# precision
precision: Union[str, Precision] = Precision.FP32,
# optimizers
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
# scaler
scaler: Optional[torch.cuda.amp.grad_scaler.GradScaler] = None,
# algorithms and callbacks
algorithms: Optional[Union[Algorithm, Sequence[Algorithm]]] = None,
callbacks: Optional[Union[Callback, Sequence[Callback]]] = None,
# Distributed training configs
deepspeed_config: Optional[Dict[str, Any]] = None,
fsdp_config: Optional[Dict[str, Any]] = None,
):
self.rank_zero_seed = rank_zero_seed
self.model = model
self.run_name = run_name
self.device = device
self.grad_accum = grad_accum
self.eval_batch_split = eval_batch_split
self.device_train_microbatch_size = device_train_microbatch_size
self.auto_microbatching = auto_microbatching
self.using_device_microbatch_size = using_device_microbatch_size
self._dataloader_len = None
self._dataloader = None
self._dataloader_label = None
self.set_dataloader(dataloader, dataloader_label, dataloader_len)
self.dataset_state = dataset_state
self.dataset_resumption = dataset_resumption or {}
self._max_duration = None
self.max_duration = max_duration
self._train_dataloader = train_dataloader
self._evaluators = list(ensure_tuple(evaluators))
self.timestamp = Timestamp()
self.eval_timestamp = Timestamp()
self.predict_timestamp = Timestamp()
self._precision = Precision(precision)
if optimizers is None:
self._optimizers = []
else:
self._optimizers = list(ensure_tuple(optimizers))
self._schedulers = []
self.scaler = scaler
self._algorithms = list(ensure_tuple(algorithms))
self._callbacks = list(ensure_tuple(callbacks))
self.profiler: Optional[Profiler] = None
self.deepspeed_config = deepspeed_config
self.fsdp_config = fsdp_config
self.fsdp_state_dict_type: Optional[str] = None
if self.fsdp_enabled:
if self.fsdp_config is not None:
self.fsdp_state_dict_type = self.fsdp_config.get('state_dict_type', 'full')
else:
self.fsdp_state_dict_type = 'full'
# Set defaults for transient variables (to make pyright happy)
self.batch: Any = None
self.loss: Union[torch.Tensor, Sequence[torch.Tensor]] = torch.Tensor()
self.outputs: Union[torch.Tensor, Sequence[torch.Tensor]] = torch.Tensor()
# These attributes will be serialized using .state_dict(), and loaded with .load_state_dict()
# All other attributes will not be serialized.
# For simplicity, omit the leading underscore for private attributes.
# For example, even though the optimizers are stored on the state
# as the "_optimizers" attribute, here we specify just "optimizers"
self.serialized_attributes = [
'model',
'optimizers',
'schedulers',
'algorithms',
'callbacks',
'scaler',
'timestamp',
'rank_zero_seed',
'train_metrics',
'eval_metrics',
'run_name',
'dataset_state',
]
self.train_metrics: Dict[str, Metric] = {}
self.eval_metrics: Dict[str, Dict[str, Metric]] = {}
self.train_metric_values: Dict[str, float] = {}
self.eval_metric_values: Dict[str, float] = {}
self.total_loss_dict: Dict[str, float] = {}
def _dataset_of(self, dataloader: Optional[Union[Evaluator, DataSpec, DataLoader, Iterable]]) -> Optional[Dataset]:
"""Get the dataset contained by the given dataloader-like object.
Args:
dataloader (Evaluator | DataSpec | DataLoader | Iterable, optional): The dataloader, wrapped dataloader, or
generic python iterable to get the dataset of, if applicable.
Returns:
Dataset: Its dataset, if there is one.
"""
from composer.core.evaluator import Evaluator
# If it's None, no dataset for you.
if dataloader is None:
return None
# An Evaluator is a dataloader wrapped with metrics. Unwrap its dataloader.
if isinstance(dataloader, Evaluator):
dataloader = dataloader.dataloader
# A DataSpec is a dataloader wrapped with an on-device transform. Unwrap its dataloader.
if isinstance(dataloader, DataSpec):
dataloader = dataloader.dataloader
# If what we now have is an actual DataLoader, return its dataset. If not, return None.
if isinstance(dataloader, DataLoader):
return dataloader.dataset
else:
return None
@property
def train_dataloader(self) -> Optional[Union[Iterable, DataLoader]]:
"""Get the train dataloader.
Returns:
Iterable | DataLoader, optional: The dataloader.
"""
return self._train_dataloader
@train_dataloader.setter
def train_dataloader(self, train_dataloader: Optional[Union[Iterable, DataLoader]]):
"""Set the train dataloader.
Args:
train_dataloader (Iterable | DataLoader, optional): The dataloader.
"""
self._train_dataloader = train_dataloader
# Load dataset state from checkpoint when train_dataloader is set
if self.dataset_state:
dataset = self._dataset_of(self._train_dataloader)
if hasattr(dataset, 'load_state_dict'):
dataset.load_state_dict(self.dataset_state['train']) # pyright: ignore
self.dataset_resumption['train'] = True
self.dataset_state['train'] = None
@property
def current_metrics(self):
warnings.warn(
'``State.current_metrics`` is deprecated and will be removed in a future release. Please use '
'``train_metrics`` and ``eval_metrics`` instead.')
return {'train': self.train_metrics, **self.eval_metrics}
@property
def seed(self):
"""The seed for the current rank."""
return self.rank_zero_seed + dist.get_global_rank()
@property
def max_duration(self):
"""The maximum training duration."""
return self._max_duration
@max_duration.setter
def max_duration(self, max_duration: Optional[Union[str, Time[int]]]):
if max_duration is None:
self._max_duration = None
return
if isinstance(max_duration, str):
max_duration = cast(Time[int], Time.from_timestring(max_duration))
if max_duration.unit == TimeUnit.DURATION:
raise ValueError('TimeUnit.DURATION is not allowed as a unit for max_duration')
self._max_duration = max_duration
def get_elapsed_duration(self) -> Optional[Time[float]]:
"""Get the elapsed training duration.
Returns:
Optional[Time[float]]: The elapsed duration, in :attr:`TimeUnit.DURATION`.
``Time(0.0, TimeUnit.DURATION)`` represents the beginning of training and ``Time(1.0, TimeUnit.DURATION)``
represents a completed training process. Returns ``None`` if ``max_duration`` is None.
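For example, a callback or algorithm might branch on training progress (a minimal sketch, assuming ``state`` is a :class:`State` instance with ``max_duration`` set):

.. code-block:: python

    elapsed = state.get_elapsed_duration()
    if elapsed is not None and elapsed.value >= 0.5:
        # more than halfway through training
        ...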
"""
if self.max_duration is None:
return None
return self.timestamp.get(self.max_duration.unit) / self.max_duration
def stop_training(self):
"""Gracefully stop training.
The current batch of training will finish, and any scheduled evaluation and logging
for that batch, as well as any epoch end events, will run before training stops.
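For example, a callback could call this method to implement a custom stopping condition (a minimal sketch; ``StopAfterTenBatches`` is a hypothetical callback, not part of Composer):

.. code-block:: python

    from composer.core import Callback, State
    from composer.loggers import Logger

    class StopAfterTenBatches(Callback):
        """Hypothetical callback that halts training after ten batches."""

        def batch_end(self, state: State, logger: Logger) -> None:
            if state.timestamp.batch.value >= 10:
                state.stop_training()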
"""
self.max_duration = self.timestamp.batch
@property
def optimizers(self):
"""The optimizers."""
return self._optimizers
@optimizers.setter
def optimizers(self, optimizers: Union[Optimizer, Sequence[Optimizer]]):
self._optimizers[:] = ensure_tuple(optimizers)
@property
def schedulers(self):
"""The schedulers."""
return self._schedulers
@schedulers.setter
def schedulers(self, schedulers: Union[types.PyTorchScheduler, Sequence[types.PyTorchScheduler]]):
self._schedulers[:] = ensure_tuple(schedulers)
def batch_get_item(self, key: Union[str, int, Callable, Any]) -> Any:
"""Gets an element of the batch, specified either by a key or by a user-specified function.
See batch_get in `utils/batch_helpers.py` for examples.
Args:
key (str | int | Tuple[Callable, Callable] | Any, optional): A key to index into the batch or a
user-specified function to do the extracting. A pair of callables is also
supported for cases where a get and set function pair are both passed
(like in Algorithms). The getter is assumed to be the first of the pair.
Returns:
The part of the batch specified by the key. This could be any type
depending on what the batch is composed of.
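For example (a minimal sketch, assuming ``state.batch`` is an ``(inputs, targets)`` tuple):

.. code-block:: python

    inputs = state.batch_get_item(0)   # index into a tuple batch
    targets = state.batch_get_item(1)
    # for mapping-style batches, a string key works the same way, e.g. state.batch_get_item('input_ids')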
"""
return batch_get(self.batch, key)
def batch_set_item(self, key: Union[str, int, Callable, Any], value: Any):
"""Sets the element of the batch specified by ``key`` (or by a user-specified set function) to ``value``.
This is not an in-place operation, as for tuple-typed batches, a new batch object
must be created to modify them.
See batch_set in `utils/batch_helpers.py` for examples.
Args:
key (str | int | Tuple[Callable, Callable] | Any, optional): A key to index into the batch or a user-specified
function to do the setting. A pair of callables is also supported for
cases where a get and set function pair are both passed (like in
Algorithms). The setter is assumed to be the second of the pair.
value (Any): The value that batch[key] or batch.key gets set to or that the
user-defined set function sets a part of the batch to.
Returns:
batch (Any): The updated batch with value set at key.
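For example (a minimal sketch, assuming ``state.batch`` is an ``(inputs, targets)`` tuple and ``transform`` is a hypothetical function):

.. code-block:: python

    new_inputs = transform(state.batch_get_item(0))
    state.batch_set_item(0, new_inputs)  # state.batch is replaced with the updated tuple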
"""
self.batch = batch_set(self.batch, key=key, value=value)
@property
def callbacks(self):
"""The callbacks."""
return self._callbacks
@callbacks.setter
def callbacks(self, callbacks: Sequence[Callback]):
self._callbacks[:] = callbacks
@property
def algorithms(self):
"""The algorithms."""
return self._algorithms
@algorithms.setter
def algorithms(self, algorithms: Sequence[Algorithm]):
self._algorithms[:] = algorithms
@property
def evaluators(self):
"""The evaluators."""
return self._evaluators
@evaluators.setter
def evaluators(self, evaluators: Union[Evaluator, Sequence[Evaluator]]):
self._evaluators[:] = list(ensure_tuple(evaluators))
# Load dataset state from checkpoint when evaluators are set
if self.dataset_state:
state = self.dataset_state['eval']
for evaluator in self._evaluators:
dataset = self._dataset_of(evaluator)
if hasattr(dataset, 'load_state_dict') and evaluator.label in state:
dataset.load_state_dict(state[evaluator.label]) # pyright: ignore
del self.dataset_state['eval']
@property
def deepspeed_enabled(self):
"""Indicates if deepspeed is enabled."""
return self.deepspeed_config is not None
@property
def fsdp_enabled(self):
"""Indicates if FSDP is enabled."""
if version.parse(torch.__version__) < version.parse('1.13.0'):
return False
from torch.distributed.fsdp import FullyShardedDataParallel
for module in self.model.modules():
if isinstance(module, FullyShardedDataParallel):
return True
return False
@property
def fsdp_sharded_state_dict_enabled(self):
if self.fsdp_config is None:
return False
return (self.fsdp_enabled and self.fsdp_state_dict_type in ['sharded', 'local'])
def _get_integrations_state_dict(self) -> Dict[str, Any]:
"""Gets a dictionary of information about integrations to store in the state dict.
This metadata is used for loading things from state dict that need to be done outside
of the normal Composer load path (e.g. HuggingFace model/tokenizer).
"""
from composer.models import HuggingFaceModel
integrations = {}
if isinstance(self.model, HuggingFaceModel):
integrations['huggingface'] = self.model.get_metadata()
return integrations
def _get_state_metadata(self) -> Dict[str, Any]:
"""Gets a dictionary of metadata to store in the state dict.
This metadata is used for checking compatibility between the current environment/setup
and the environment/setup that was used for the checkpoint that is being loaded.
"""
metadata_dict = {}
metadata_dict['composer_env_info'] = get_composer_env_dict()
metadata_dict['device'] = self.device.name
metadata_dict['precision'] = self.precision.value
metadata_dict['world_size'] = dist.get_world_size()
metadata_dict['device_train_microbatch_size'] = self.device_train_microbatch_size
if self._train_dataloader is not None and hasattr(self._train_dataloader, 'batch_size'):
metadata_dict['train_dataloader_batch_size'] = self._train_dataloader.batch_size # type: ignore
return metadata_dict
def _dataset_state_dict(self) -> Dict[str, Any]:
"""Collect the state dict(s) of our train and eval dataset(s).
Returns:
Dict[str, Any]: The state dict(s).
"""
obj = {
'train': None,
'eval': {},
}
dataset = self._dataset_of(self.train_dataloader)
if hasattr(dataset, 'state_dict'):
num_samples = int(self.timestamp.sample_in_epoch.value)
obj['train'] = dataset.state_dict(num_samples, True) # pyright: ignore
for evaluator in self.evaluators:
dataset = self._dataset_of(evaluator)
if hasattr(dataset, 'state_dict'):
# Don't save eval sample because we do not checkpoint during eval.
obj['eval'][evaluator.label] = dataset.state_dict(0, True) # pyright: ignore
return obj
def state_dict(self) -> Dict[str, Any]:
"""Collect the state dicts of our serializable attributes.
Returns:
Dict[str, Any]: The state dict.
"""
state_dict = {}
for attribute_name in self.serialized_attributes:
attribute_value = getattr(self, attribute_name)
if attribute_name == 'dataset_state':
serialized_value = self._dataset_state_dict()
elif attribute_name == 'model':
# Save model directly instead of by class name, since model may be wrapped by DistributedDataParallel
# If it is DDP wrapped, do not save the `module.` prefix, as that is an implementation detail
if self.fsdp_enabled and self.fsdp_state_dict_type is not None:
with fsdp_state_dict_type_context(attribute_value, state_dict_type=self.fsdp_state_dict_type):
model_state = attribute_value.state_dict()
else:
model_state = attribute_value.state_dict()
if self.is_model_ddp:
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state, 'module.')
serialized_value = model_state
elif attribute_name == 'optimizers':
# Let's stop pretending. We don't support more than one optimizer.
optimizer = ensure_tuple(attribute_value)[0]
if self.fsdp_enabled and self.fsdp_state_dict_type is not None:
optim_state_dict = {
type(optimizer).__qualname__:
fsdp_get_optim_state_dict(self.model, optimizer, state_dict_type=self.fsdp_state_dict_type)
}
else:
optim_state_dict = {type(optimizer).__qualname__: optimizer.state_dict()}
serialized_value = optim_state_dict
elif attribute_name == 'algorithms':
# Store as list to preserve order in which algorithms were applied
serialized_value = [(type(obj).__qualname__, obj.state_dict()) for obj in ensure_tuple(attribute_value)]
elif attribute_name in _STATE_DICT_SERIALIZED_ATTRIBUTES:
serialized_value = {type(obj).__qualname__: obj.state_dict() for obj in ensure_tuple(attribute_value)}
else:
serialized_value = attribute_value
state_dict[attribute_name] = serialized_value
state_dict['integrations'] = self._get_integrations_state_dict()
state_dict['metadata'] = self._get_state_metadata()
return state_dict
def _apply_required_algorithms(
self,
state_dict: Dict[str, Any],
logger: Logger,
exclude_algorithms: Optional[List[str]] = None,
algorithm_passes: Optional[List[AlgorithmPass]] = None,
):
"""Applies required algorithms which haven't been specified and aren't in the exclude list.
Args:
state_dict (Dict[str, Any]): State from checkpoint.
logger (Logger): Logger to use.
exclude_algorithms (List[str], optional): List of algorithm names to exclude. (default: ``None``)
algorithm_passes (List[AlgorithmPass], optional): A list of algorithm passes to apply to autoloaded algorithms
to sort them into the correct order. (default: ``None``)
"""
# Don't try to autoload on old checkpoints
if not isinstance(state_dict['algorithms'], list):
return
import composer.algorithms as algorithms # type: ignore imports used in `eval(representation)`
# Get repr of existing algorithms
current_algos = {}
for algo in self.algorithms:
if algo.required_on_load():
if type(algo) not in current_algos:
current_algos[type(algo)] = []
current_algos[type(algo)].append(algo.__repr__())
# Gather algorithms to apply
missing_algos = set()
missing_algo_names = []
missing_algo_reprs = []
for algo_name, serialized_value in state_dict['algorithms']:
# Check if required algorithm
if hasattr(algorithms, algo_name) and getattr(algorithms, algo_name).required_on_load():
# Check that algorithm is not explicitly excluded by user
if exclude_algorithms is None or algo_name not in exclude_algorithms:
try:
algo = eval(f"algorithms.{serialized_value['repr']}")
except:
warnings.warn(
textwrap.dedent(
f"required_on_load algorithm {serialized_value['repr']} was enabled when training the "
f'loaded checkpoint. Attempted to check its presence but recreating the algorithm '
"failed. This may be due to a change in the algorithm's API. If this required_on_load "
'algorithm is not properly specified, it may lead to unexpected behavior, including '
'failing to load weights for some layers.'))
continue
# Raise warning if we are unable to safely autoapply
if type(algo) in current_algos and not serialized_value['repr'] in current_algos[type(algo)]:
warnings.warn(
textwrap.dedent(
f"required_on_load algorithm {serialized_value['repr']} was enabled when training the "
f"loaded checkpoint but is now specified in the following forms: {', '.join(current_algos[type(algo)])}."
'Potential parameter discrepancies for this required_on_load algorithm may lead to '
'unexpected behavior, including failing to load weights for some layers.'))
# Otherwise, queue algorithm to be autoapplied
elif type(algo) not in current_algos:
missing_algos.add(algo)
missing_algo_names.append(algo_name)
missing_algo_reprs.append(serialized_value['repr'])
self.algorithms.append(algo)
# Reorder algorithms based on algorithm_passes from engine
algo_list = self.algorithms
if algorithm_passes is not None:
for algo_pass in algorithm_passes:
algo_list = algo_pass(algo_list, Event.INIT)
# Raise ValueError if algorithm_passes order any checkpoint algorithm before an already
# applied user specified algorithm
encountered_ckpt_algo = False
for algo in algo_list:
if algo in missing_algos:
encountered_ckpt_algo = True
elif encountered_ckpt_algo:
raise ValueError(
textwrap.dedent('The following algorithms were enabled when training this checkpoint '
f'and are required to successfully load it: {missing_algo_reprs}. '
'Attempted to autocreate and apply required algorithms, but at least one '
'of the loaded algorithms was ordered before a user specified algorithm '
'which has already been applied, preventing automatic application of '
'algorithms. If you wish to use pretrained weights and reinitialize '
'layers which have undergone surgery, the following algorithms may be '
'excluded using `load_exclude_algorithms`, e.g. '
f'`load_exclude_algorithms=[{missing_algo_names}]`.'))
try:
for algo in missing_algos: # TODO: use compiled algorithm order
if algo.match(Event.INIT, self):
algo.apply(Event.INIT, self, logger)
warnings.warn(
textwrap.dedent(
f'Automatically adding required_on_load algorithm {repr(algo)} to trainer, which was enabled '
'when training the loaded checkpoint. If you wish to use pretrained weights and ignore '
f'required_on_load algorithms, which may result in some weights failing to load, include {type(algo).__qualname__} '
f"in `load_exclude_algorithms`, e.g. `load_exclude_algorithms=['{type(algo).__qualname__}']`."))
except Exception as e:
raise ValueError(
textwrap.dedent(
'The following algorithms were enabled when training this checkpoint '
f'and are required to successfully load it: {missing_algo_reprs}. '
'Attempted to autocreate and apply required algorithms but an exception was '
'encountered. If you wish to use pretrained weights and reinitialize layers which '
'have undergone surgery, the following algorithms may be excluded using '
f'`load_exclude_algorithms`, e.g. `load_exclude_algorithms=[{missing_algo_names}]`.')) from e
def load_model_state(
self,
state_dict: Dict[str, Any],
logger: Logger,
strict: bool,
exclude_algorithms: Optional[List[str]] = None,
algorithm_passes: Optional[List[AlgorithmPass]] = None,
):
"""Loads the model's state from a ``state_dict``.
Args:
state_dict (Dict[str, Any]): The state dict, generated from a previous call to :meth:`state_dict`.
logger (Logger): The logger.
strict (bool): Whether the keys (i.e., model parameter names) in the model state dict should
perfectly match the keys in the model instance.
exclude_algorithms (List[str], optional): List of algorithm names to exclude from autoloading. (default: ``None``)
algorithm_passes (List[AlgorithmPass], optional): A list of algorithm passes to apply to autoloaded algorithms
to sort them into the correct order. (default: ``None``)
"""
if 'algorithms' in state_dict:
self._apply_required_algorithms(state_dict, logger, exclude_algorithms, algorithm_passes)
if state_dict.get('is_model_ddp', False) and not self.is_model_ddp:
# This check is for backwards compatibility, as pre-v0.6.0 checkpoints serialized the state
# with the `module.` prefix
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.')
if self.fsdp_enabled and self.fsdp_state_dict_type is not None:
with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type):
missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
else:
missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
if len(missing_keys) > 0:
log.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
if len(unexpected_keys) > 0:
log.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
def load_optim_state(self, state_dict: Dict[str, Any]):
"""Load the optimizer state.
Args:
state_dict (Dict[str, Any]): The state to load.
"""
serialized_value = state_dict['optimizers']
for optimizer in ensure_tuple(self.optimizers):
if type(optimizer).__qualname__ not in serialized_value:
warnings.warn(
f'{type(optimizer).__qualname__} is not in the state_dict. Its state will not be restored.',
category=UserWarning)
continue
optim_state_dict = serialized_value[type(optimizer).__qualname__]
if self.fsdp_enabled:
log.debug(f'Loading FSDP optimizer with fsdp_state_dict_type={self.fsdp_state_dict_type}')
if self.fsdp_state_dict_type == 'sharded':
if version.parse(torch.__version__) < version.parse('1.13.0'):
raise RuntimeError('To use FSDP with Composer, you must use torch>=1.13.0.')
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# Optimizer and optimizer state dict are already sharded, but not
# flattened, so we flatten the state dict then load it.
flattened_optim_state_dict = FSDP.flatten_sharded_optim_state_dict(
sharded_optim_state_dict=optim_state_dict, model=self.model, optim=optimizer)
optimizer.load_state_dict(flattened_optim_state_dict)
elif self.fsdp_state_dict_type == 'local':
# Optimizer and optimizer state dict are already sharded and flattened,
# so just load the state_dict.
optimizer.load_state_dict(optim_state_dict)
else: # fsdp_state_dict_type == 'full'
# FSDP enabled, but fsdp_state_dict is set to 'full', so the state dict
# is a full state dict and we must shard and flatten it first before loading it.
sharded_optim_state_dict = get_fsdp_sharded_optim_state_dict(full_optim_state_dict=optim_state_dict,
model=self.model)
log.debug(f'optimizer.load_state_dict call with fsdp_state_dict_type=full')
optimizer.load_state_dict(sharded_optim_state_dict)
# No FSDP, so just load the optim state dict.
else:
log.debug(f'Loading optimizer state dict')
optimizer.load_state_dict(optim_state_dict)
def _load_dataset_state(self, obj: Dict[str, Any]) -> None:
"""Load the dataset state.
Args:
obj (Dict[str, Any]): The state to load.
"""
self.dataset_state = obj
dataset = self._dataset_of(self.train_dataloader)
if hasattr(dataset, 'load_state_dict'):
dataset.load_state_dict(obj['train']) # pyright: ignore
obj['train'] = None
self.dataset_resumption['train'] = True
for evaluator in self.evaluators:
dataset = self._dataset_of(evaluator)
if hasattr(dataset, 'load_state_dict') and evaluator.label in obj['eval']:
dataset.load_state_dict(obj['eval'][evaluator.label]) # pyright: ignore
del obj['eval'][evaluator.label]
if 'eval' not in self.dataset_resumption:
self.dataset_resumption['eval'] = {}
# Note: We currently disable setting dataset_resumption for eval datasets,
# which means they have one sample fetched in _spin_dataloaders before training
# starts. This avoids "CUDA error: initialization error" -- it's not clear why.
# self.dataset_resumption['eval'][evaluator.label] = True
def load_state_dict(
self,
state: Dict[str, Any],
logger: Logger,
strict: bool = False,
exclude_algorithms: Optional[List[str]] = None,
algorithm_passes: Optional[List[AlgorithmPass]] = None,
):
"""Loads the state.
Args:
state (Dict[str, Any]): object returned from call to :meth:`state_dict`.
logger (Logger): The logger.
strict (bool): whether the keys in the ``state["model"]`` should perfectly match the keys in the
``self.model``. Defaults to False.
exclude_algorithms (List[str], optional): List of algorithm names to exclude from autoloading. (default: ``None``)
algorithm_passes (List[AlgorithmPass], optional): A list of algorithm passes to apply to autoloaded algorithms
to sort them into the correct order. (default: ``None``)
"""
state = _ensure_backwards_compatible_checkpointing(state)
# Call load_model_state since it applies required algorithms
if 'model' in state:
self.load_model_state(
state,
logger,
strict=strict,
exclude_algorithms=exclude_algorithms,
algorithm_passes=algorithm_passes,
)
for attribute_name, serialized_value in state.items():
# Skip removed attributes as well as algorithms and model, which was already loaded
if attribute_name not in self.serialized_attributes or attribute_name == 'model':
continue
# Integrations are extra information about other libraries (e.g. huggingface) and not attributes to be loaded here
if attribute_name == 'integrations':
continue
# Skip metadata, which is not an attribute on State
if attribute_name == 'metadata':
continue
log.debug(f'Loading {attribute_name} into state.')
# Restructure algorithms serialized_value from list to dict
if attribute_name == 'algorithms' and isinstance(serialized_value, list):
serialized_value = {algo_name: algo_serialized for algo_name, algo_serialized in serialized_value}
if attribute_name == 'dataset_state':
self._load_dataset_state(serialized_value)
elif attribute_name == 'optimizers':
self.load_optim_state(state)
elif attribute_name == 'train_metrics':
state_field_value = getattr(self, attribute_name)
for metric_name, metric in serialized_value.items():
metric._device = self.device._device
state_field_value[metric_name] = metric
elif attribute_name == 'eval_metrics':
state_field_value = getattr(self, attribute_name)
for eval_key, eval_metrics in serialized_value.items():
for metric_name, metric in eval_metrics.items():
metric._device = self.device._device
state_field_value[eval_key][metric_name] = metric
elif attribute_name in _STATE_DICT_SERIALIZED_ATTRIBUTES:
state_field_value = getattr(self, attribute_name)
for target in ensure_tuple(state_field_value):
if type(target).__qualname__ not in serialized_value:
warnings.warn(
f'{type(target).__qualname__} is not in the state_dict. Its state will not be restored.',
category=UserWarning)
continue
source = serialized_value[type(target).__qualname__]
target.load_state_dict(source)
else:
# direct serialization
try:
setattr(self, attribute_name, serialized_value)
except AttributeError:
# ignore AttributeError for properties that have getters but not setters.
pass
@property
def dataloader(self):
"""The active dataloader."""
return self._dataloader
@property
def dataloader_label(self):
"""The dataloader label for the active dataloader.
By default, the training dataloader is called ``'train'``. The evaluator dataloader
is called ``'eval'``, or when multiple evaluators are used, the name of the evaluator.
However, the dataloader label can be explicitly specified in :meth:`.Trainer.fit`
and :meth:`.Trainer.eval`.
Returns:
Optional[str]: The dataloader label, or None if no dataloader is set.
"""
return self._dataloader_label
def set_dataloader(
self,
dataloader: Optional[Iterable] = None,
dataloader_label: Optional[str] = None,
dataloader_len: Union[int, Time[int]] = -1,
):
"""Update the active dataloader and dataloader label.
Args:
dataloader (Iterable, optional): The dataloader. Defaults to None.
dataloader_label (str, optional): The dataloader label. Must be ``None`` if and only if
``dataloader`` is None. Defaults to None.
dataloader_len (int | Time[int], optional): The number of batches per dataloader iteration (e.g. epoch), as used by the trainer.
Set to ``-1`` to iterate over the entire dataset. (Default: ``-1``.)
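For example (a minimal sketch; ``my_eval_dataloader`` is a hypothetical dataloader):

.. code-block:: python

    state.set_dataloader(my_eval_dataloader, 'eval')
    assert state.dataloader_label == 'eval'
    state.set_dataloader(None)  # clears the active dataloader and its label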
"""
if dataloader is None:
dataloader_label = None
else:
if dataloader_label is None:
raise ValueError('If the `dataloader` is specified, then `dataloader_label` must not be None.')
self._dataloader = dataloader
self._dataloader_label = dataloader_label
if dataloader is not None:
self.dataloader_len = dataloader_len # setting it to -1 will do a failsafe read of len(dataloader)
else:
self._dataloader_len = None
@property
def dataloader_len(self):
"""The number of batches per dataloader iteration (e.g. epoch), as used by the trainer.
.. note::
If not explicitly specified, this value is an approximation, as it depends on ``len(self.dataloader)``.
See the :doc:`PyTorch DataLoader Documentation <torch:data>` for more information.
Returns:
Optional[Time[int]]: The number of batches per dataloader iteration (e.g. epoch), or None if no dataloader
is defined or if the dataloader has an unknown length (e.g. streaming dataloaders).
"""
return self._dataloader_len
@dataloader_len.setter
def dataloader_len(self, num_batches: Union[int, Time[int]]):
if isinstance(num_batches, int):
num_batches = Time(num_batches, TimeUnit.BATCH)
if self._dataloader is None:
raise RuntimeError('`State.dataloader_len` cannot be set if the dataloader is not defined.')
try:
if isinstance(self._dataloader, collections.abc.Sized):
dataloader_len = len(self._dataloader)
else:
dataloader_len = None
except (TypeError, NotImplementedError):
dataloader_len = None
if dataloader_len is not None and num_batches >= 0 and int(num_batches) > dataloader_len:
warnings.warn((f'DataloaderNumBatchesWarning: The dataloader_len ({int(num_batches)}) '
f'is greater than the length (i.e. number of batches) of the dataloader, which is '
f'{dataloader_len}. State.dataloader_len is thus being set to {dataloader_len}.'))
self._dataloader_len = Time(dataloader_len, TimeUnit.BATCH)
return
if num_batches < 0:
if dataloader_len is not None:
# len(dataloader) is an approximation -- see https://pytorch.org/docs/stable/data.html.
# However, in the worst case where additional last batches are dropped, this calculation should be
# an over-estimate, leading to the entire dataloader still being iterated over.
self._dataloader_len = Time(dataloader_len, TimeUnit.BATCH)
else:
# The dataloader length is unknown.
self._dataloader_len = None
return
self._dataloader_len = num_batches
@property
def precision(self):
"""The numerical precision to use for training.
See :class:`~.Precision` for the supported precisions.
"""
return self._precision
@precision.setter
def precision(self, precision: Union[str, Precision]):
self._precision = Precision(precision)
@property
def is_model_ddp(self):
"""Whether :attr:`model` is an instance of a :class:`.DistributedDataParallel`."""
return isinstance(self.model, DistributedDataParallel)
@property
def deepspeed_model(self) -> deepspeed.DeepSpeedEngine:
"""Cast :attr:`model` to :class:`~deepspeed.DeepSpeedEngine`."""
if is_model_deepspeed(self.model):
return cast('deepspeed.DeepSpeedEngine', self.model)
raise TypeError('state.model is not a DeepSpeed model')