Source code for composer.core.event

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Training Loop Events."""
from composer.utils import StringEnum

__all__ = ['Event']


[docs]class Event(StringEnum): """Enum to represent training loop events. Events mark specific points in the training loop where an :class:`~.core.Algorithm` and :class:`~.core.Callback` can run. The following pseudocode shows where each event fires in the training loop: .. code-block:: python # <INIT> # <BEFORE_LOAD> # <AFTER_LOAD> # <FIT_START> for epoch in range(NUM_EPOCHS): # <EPOCH_START> while True: # <BEFORE_DATALOADER> batch = next(dataloader) if batch is None: break # <AFTER_DATALOADER> # <BATCH_START> # <BEFORE_TRAIN_BATCH> for microbatch in batch.split(device_train_microbatch_size): # <BEFORE_FORWARD> outputs = model(batch) # <AFTER_FORWARD> # <BEFORE_LOSS> loss = model.loss(outputs, batch) # <AFTER_LOSS> # <BEFORE_BACKWARD> loss.backward() # <AFTER_BACKWARD> # Un-scale gradients # <AFTER_TRAIN_BATCH> optimizer.step() # <BATCH_END> # <BEFORE_EVAL_ALL> for eval_dataloader in eval_dataloaders: if should_eval(batch=True): # <EVAL_START> for batch in eval_dataloader: # <EVAL_BATCH_START> # <EVAL_BEFORE_FORWARD> outputs, targets = model(batch) # <EVAL_AFTER_FORWARD> metrics.update(outputs, targets) # <EVAL_BATCH_END> # <EVAL_END> # <AFTER_EVAL_ALL> # <BATCH_CHECKPOINT> # <EPOCH_END> # <BEFORE_EVAL_ALL> for eval_dataloader in eval_dataloaders: if should_eval(batch=True): # <EVAL_START> for batch in eval_dataloader: # <EVAL_BATCH_START> # <EVAL_BEFORE_FORWARD> outputs, targets = model(batch) # <EVAL_AFTER_FORWARD> metrics.update(outputs, targets) # <EVAL_BATCH_END> # <EVAL_END> # <AFTER_EVAL_ALL> # <EPOCH_CHECKPOINT> # <FIT_END> Attributes: INIT: Invoked in the constructor of :class:`~.trainer.Trainer`. Model surgery (see :mod:`~composer.utils.module_surgery`) typically occurs here. BEFORE_LOAD: Immediately before the checkpoint is loaded in :class:`~.trainer.Trainer`. AFTER_LOAD: Immediately after checkpoint is loaded in constructor of :class:`~.trainer.Trainer`. FIT_START: Invoked at the beginning of each call to :meth:`.Trainer.fit`. Dataset transformations typically occur here. EPOCH_START: Start of an epoch. BEFORE_DATALOADER: Immediately before the dataloader is called. AFTER_DATALOADER: Immediately after the dataloader is called. Typically used for on-GPU dataloader transforms. BATCH_START: Start of a batch. BEFORE_TRAIN_BATCH: Before the forward-loss-backward computation for a training batch. When using gradient accumulation, this is still called only once. BEFORE_FORWARD: Before the call to ``model.forward()``. This is called multiple times per batch when using gradient accumulation. AFTER_FORWARD: After the call to ``model.forward()``. This is called multiple times per batch when using gradient accumulation. BEFORE_LOSS: Before the call to ``model.loss()``. This is called multiple times per batch when using gradient accumulation. AFTER_LOSS: After the call to ``model.loss()``. This is called multiple times per batch when using gradient accumulation. BEFORE_BACKWARD: Before the call to ``loss.backward()``. This is called multiple times per batch when using gradient accumulation. AFTER_BACKWARD: After the call to ``loss.backward()``. This is called multiple times per batch when using gradient accumulation. AFTER_TRAIN_BATCH: After the forward-loss-backward computation for a training batch. When using gradient accumulation, this event still fires only once. BATCH_END: End of a batch, which occurs after the optimizer step and any gradient scaling. BATCH_CHECKPOINT: After :attr:`.Event.BATCH_END` and any batch-wise evaluation. Saving checkpoints at this event allows the checkpoint saver to use the results from any batch-wise evaluation to determine whether a checkpoint should be saved. EPOCH_END: End of an epoch. EPOCH_CHECKPOINT: After :attr:`.Event.EPOCH_END` and any epoch-wise evaluation. Saving checkpoints at this event allows the checkpoint saver to use the results from any epoch-wise evaluation to determine whether a checkpointshould be saved. FIT_END: Invoked at the end of each call to :meth:`.Trainer.fit`. This event exists primarily for logging information and flushing callbacks. Algorithms should not transform the training state on this event, as any changes will not be preserved in checkpoints. EVAL_BEFORE_ALL: Before any evaluators process validation dataset. EVAL_START: Start of evaluation through the validation dataset. EVAL_BATCH_START: Before the call to ``model.eval_forward(batch)`` EVAL_BEFORE_FORWARD: Before the call to ``model.eval_forward(batch)`` EVAL_AFTER_FORWARD: After the call to ``model.eval_forward(batch)`` EVAL_BATCH_END: After the call to ``model.eval_forward(batch)`` EVAL_END: End of evaluation through the validation dataset. EVAL_AFTER_ALL: After all evaluators process validation dataset. EVAL_STANDALONE_START: Start of evaluation through a direct call to `trainer.eval`. EVAL_STANDALONE_END: End of evaluation through a direct call to `trainer.eval`. """ INIT = 'init' BEFORE_LOAD = 'before_load' AFTER_LOAD = 'after_load' FIT_START = 'fit_start' EPOCH_START = 'epoch_start' BEFORE_DATALOADER = 'before_dataloader' AFTER_DATALOADER = 'after_dataloader' BATCH_START = 'batch_start' BEFORE_TRAIN_BATCH = 'before_train_batch' BEFORE_FORWARD = 'before_forward' AFTER_FORWARD = 'after_forward' BEFORE_LOSS = 'before_loss' AFTER_LOSS = 'after_loss' BEFORE_BACKWARD = 'before_backward' AFTER_BACKWARD = 'after_backward' AFTER_TRAIN_BATCH = 'after_train_batch' BATCH_END = 'batch_end' BATCH_CHECKPOINT = 'batch_checkpoint' EPOCH_END = 'epoch_end' EPOCH_CHECKPOINT = 'epoch_checkpoint' FIT_END = 'fit_end' EVAL_BEFORE_ALL = 'eval_before_all' EVAL_START = 'eval_start' EVAL_BATCH_START = 'eval_batch_start' EVAL_BEFORE_FORWARD = 'eval_before_forward' EVAL_AFTER_FORWARD = 'eval_after_forward' EVAL_BATCH_END = 'eval_batch_end' EVAL_END = 'eval_end' EVAL_AFTER_ALL = 'eval_after_all' EVAL_STANDALONE_START = 'eval_standalone_start' EVAL_STANDALONE_END = 'eval_standalone_end' PREDICT_START = 'predict_start' PREDICT_BATCH_START = 'predict_batch_start' PREDICT_BEFORE_FORWARD = 'predict_before_forward' PREDICT_AFTER_FORWARD = 'predict_after_forward' PREDICT_BATCH_END = 'predict_batch_end' PREDICT_END = 'predict_end' @property def is_before_event(self) -> bool: """Whether the event is an "before" event. An "before" event (e.g., :attr:`~Event.BEFORE_LOSS`) has a corresponding "after" event (.e.g., :attr:`~Event.AFTER_LOSS`). """ return self in _BEFORE_EVENTS @property def is_after_event(self) -> bool: """Whether the event is an "after" event. An "after" event (e.g., :attr:`~Event.AFTER_LOSS`) has a corresponding "before" event (.e.g., :attr:`~Event.BEFORE_LOSS`). """ return self in _AFTER_EVENTS @property def canonical_name(self) -> str: """The name of the event, without before/after markers. Events that have a corresponding "before" or "after" event share the same canonical name. Example: >>> Event.EPOCH_START.canonical_name 'epoch' >>> Event.EPOCH_END.canonical_name 'epoch' Returns: str: The canonical name of the event. """ name: str = self.value name = name.replace('before_', '') name = name.replace('after_', '') name = name.replace('_start', '') name = name.replace('_end', '') return name @property def is_predict(self) -> bool: """Whether the event is during the predict loop.""" return self.value.startswith('predict') @property def is_eval(self) -> bool: """Whether the event is during the eval loop.""" return self.value.startswith('eval')
_BEFORE_EVENTS = (Event.BEFORE_LOAD, Event.FIT_START, Event.EPOCH_START, Event.BEFORE_DATALOADER, Event.BATCH_START, Event.BEFORE_TRAIN_BATCH, Event.BEFORE_FORWARD, Event.BEFORE_LOSS, Event.BEFORE_BACKWARD, Event.EVAL_BEFORE_ALL, Event.EVAL_START, Event.EVAL_BATCH_START, Event.EVAL_BEFORE_FORWARD, Event.PREDICT_START, Event.PREDICT_BATCH_START, Event.PREDICT_BEFORE_FORWARD, Event.EVAL_STANDALONE_START) _AFTER_EVENTS = (Event.AFTER_LOAD, Event.EPOCH_END, Event.BATCH_END, Event.AFTER_DATALOADER, Event.AFTER_TRAIN_BATCH, Event.AFTER_FORWARD, Event.AFTER_LOSS, Event.AFTER_BACKWARD, Event.EVAL_AFTER_ALL, Event.EVAL_END, Event.EVAL_BATCH_END, Event.EVAL_AFTER_FORWARD, Event.FIT_END, Event.PREDICT_END, Event.PREDICT_BATCH_END, Event.PREDICT_AFTER_FORWARD, Event.EVAL_STANDALONE_END)