# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Training Loop Events."""
from composer.utils import StringEnum
__all__ = ['Event']
class Event(StringEnum):
    """Enum to represent training loop events.

    Events mark specific points in the training loop where an :class:`~.core.Algorithm` and :class:`~.core.Callback`
    can run.

    The following pseudocode shows where each event fires in the training loop:

    .. code-block:: python

        # <INIT>
        # <BEFORE_LOAD>
        # <AFTER_LOAD>
        # <FIT_START>
        for iteration in range(NUM_ITERATIONS):
            # <ITERATION_START>
            for epoch in range(NUM_EPOCHS):
                # <EPOCH_START>
                while True:
                    # <BEFORE_DATALOADER>
                    batch = next(dataloader)
                    if batch is None:
                        break
                    # <AFTER_DATALOADER>

                    # <BATCH_START>
                    # <BEFORE_TRAIN_BATCH>
                    for microbatch in batch.split(device_train_microbatch_size):
                        # <BEFORE_FORWARD>
                        outputs = model(batch)
                        # <AFTER_FORWARD>

                        # <BEFORE_LOSS>
                        loss = model.loss(outputs, batch)
                        # <AFTER_LOSS>

                        # <BEFORE_BACKWARD>
                        loss.backward()
                        # <AFTER_BACKWARD>

                    # Un-scale gradients
                    # <AFTER_TRAIN_BATCH>
                    optimizer.step()
                    # <BATCH_END>

                    # <EVAL_BEFORE_ALL>
                    for eval_dataloader in eval_dataloaders:
                        if should_eval(batch=True):
                            # <EVAL_START>
                            for batch in eval_dataloader:
                                # <EVAL_BATCH_START>
                                # <EVAL_BEFORE_FORWARD>
                                outputs, targets = model(batch)
                                # <EVAL_AFTER_FORWARD>
                                metrics.update(outputs, targets)
                                # <EVAL_BATCH_END>
                            # <EVAL_END>
                    # <EVAL_AFTER_ALL>

                    # <BATCH_CHECKPOINT>
                # <EPOCH_END>

                # <EVAL_BEFORE_ALL>
                for eval_dataloader in eval_dataloaders:
                    if should_eval(batch=False):
                        # <EVAL_START>
                        for batch in eval_dataloader:
                            # <EVAL_BATCH_START>
                            # <EVAL_BEFORE_FORWARD>
                            outputs, targets = model(batch)
                            # <EVAL_AFTER_FORWARD>
                            metrics.update(outputs, targets)
                            # <EVAL_BATCH_END>
                        # <EVAL_END>
                # <EVAL_AFTER_ALL>

                # <EPOCH_CHECKPOINT>
            # <ITERATION_END>
            # <ITERATION_CHECKPOINT>
        # <FIT_END>

    Attributes:
        INIT: Invoked in the constructor of :class:`~.trainer.Trainer`. Model surgery (see
            :mod:`~composer.utils.module_surgery`) typically occurs here.
        BEFORE_LOAD: Immediately before the checkpoint is loaded in :class:`~.trainer.Trainer`.
        AFTER_LOAD: Immediately after checkpoint is loaded in constructor of :class:`~.trainer.Trainer`.
        FIT_START: Invoked at the beginning of each call to :meth:`.Trainer.fit`. Dataset transformations typically
            occur here.
        ITERATION_START: Start of an iteration.
        EPOCH_START: Start of an epoch.
        BEFORE_DATALOADER: Immediately before the dataloader is called.
        AFTER_DATALOADER: Immediately after the dataloader is called. Typically used for on-GPU dataloader transforms.
        BATCH_START: Start of a batch.
        BEFORE_TRAIN_BATCH: Before the forward-loss-backward computation for a training batch. When using gradient
            accumulation, this is still called only once.
        BEFORE_FORWARD: Before the call to ``model.forward()``.
            This is called multiple times per batch when using gradient accumulation.
        AFTER_FORWARD: After the call to ``model.forward()``.
            This is called multiple times per batch when using gradient accumulation.
        BEFORE_LOSS: Before the call to ``model.loss()``.
            This is called multiple times per batch when using gradient accumulation.
        AFTER_LOSS: After the call to ``model.loss()``.
            This is called multiple times per batch when using gradient accumulation.
        BEFORE_BACKWARD: Before the call to ``loss.backward()``.
            This is called multiple times per batch when using gradient accumulation.
        AFTER_BACKWARD: After the call to ``loss.backward()``.
            This is called multiple times per batch when using gradient accumulation.
        AFTER_TRAIN_BATCH: After the forward-loss-backward computation for a training batch. When using gradient
            accumulation, this event still fires only once.
        BATCH_END: End of a batch, which occurs after the optimizer step and any gradient scaling.
        BATCH_CHECKPOINT: After :attr:`.Event.BATCH_END` and any batch-wise evaluation. Saving checkpoints at this
            event allows the checkpoint saver to use the results from any batch-wise evaluation to determine whether
            a checkpoint should be saved.
        EPOCH_END: End of an epoch.
        EPOCH_CHECKPOINT: After :attr:`.Event.EPOCH_END` and any epoch-wise evaluation. Saving checkpoints at this
            event allows the checkpoint saver to use the results from any epoch-wise evaluation to determine whether
            a checkpoint should be saved.
        ITERATION_END: End of an iteration.
        ITERATION_CHECKPOINT: After :attr:`.Event.ITERATION_END`. Saving checkpoints at this event allows the checkpoint
            saver to determine whether a checkpoint should be saved.
        FIT_END: Invoked at the end of each call to :meth:`.Trainer.fit`. This event exists primarily for logging information
            and flushing callbacks. Algorithms should not transform the training state on this event, as any changes will not
            be preserved in checkpoints.
        EVAL_BEFORE_ALL: Before any evaluators process validation dataset.
        EVAL_START: Start of evaluation through the validation dataset.
        EVAL_BATCH_START: Before the call to ``model.eval_forward(batch)``
        EVAL_BEFORE_FORWARD: Before the call to ``model.eval_forward(batch)``
        EVAL_AFTER_FORWARD: After the call to ``model.eval_forward(batch)``
        EVAL_BATCH_END: After the call to ``model.eval_forward(batch)``
        EVAL_END: End of evaluation through the validation dataset.
        EVAL_AFTER_ALL: After all evaluators process validation dataset.
        EVAL_STANDALONE_START: Start of evaluation through a direct call to `trainer.eval`.
        EVAL_STANDALONE_END: End of evaluation through a direct call to `trainer.eval`.
        PREDICT_START: Start of the predict loop.
        PREDICT_BATCH_START: Start of a batch in the predict loop.
        PREDICT_BEFORE_FORWARD: Before the forward pass on a predict batch.
        PREDICT_AFTER_FORWARD: After the forward pass on a predict batch.
        PREDICT_BATCH_END: End of a batch in the predict loop.
        PREDICT_END: End of the predict loop.
    """

    INIT = 'init'
    BEFORE_LOAD = 'before_load'
    AFTER_LOAD = 'after_load'
    FIT_START = 'fit_start'

    ITERATION_START = 'iteration_start'
    EPOCH_START = 'epoch_start'

    BEFORE_DATALOADER = 'before_dataloader'
    AFTER_DATALOADER = 'after_dataloader'

    BATCH_START = 'batch_start'
    BEFORE_TRAIN_BATCH = 'before_train_batch'

    BEFORE_FORWARD = 'before_forward'
    AFTER_FORWARD = 'after_forward'
    BEFORE_LOSS = 'before_loss'
    AFTER_LOSS = 'after_loss'
    BEFORE_BACKWARD = 'before_backward'
    AFTER_BACKWARD = 'after_backward'

    AFTER_TRAIN_BATCH = 'after_train_batch'
    BATCH_END = 'batch_end'
    BATCH_CHECKPOINT = 'batch_checkpoint'

    EPOCH_END = 'epoch_end'
    EPOCH_CHECKPOINT = 'epoch_checkpoint'

    ITERATION_END = 'iteration_end'
    ITERATION_CHECKPOINT = 'iteration_checkpoint'

    FIT_END = 'fit_end'

    EVAL_BEFORE_ALL = 'eval_before_all'
    EVAL_START = 'eval_start'
    EVAL_BATCH_START = 'eval_batch_start'
    EVAL_BEFORE_FORWARD = 'eval_before_forward'
    EVAL_AFTER_FORWARD = 'eval_after_forward'
    EVAL_BATCH_END = 'eval_batch_end'
    EVAL_END = 'eval_end'
    EVAL_AFTER_ALL = 'eval_after_all'

    EVAL_STANDALONE_START = 'eval_standalone_start'
    EVAL_STANDALONE_END = 'eval_standalone_end'

    PREDICT_START = 'predict_start'
    PREDICT_BATCH_START = 'predict_batch_start'
    PREDICT_BEFORE_FORWARD = 'predict_before_forward'
    PREDICT_AFTER_FORWARD = 'predict_after_forward'
    PREDICT_BATCH_END = 'predict_batch_end'
    PREDICT_END = 'predict_end'

    @property
    def is_before_event(self) -> bool:
        """Whether the event is a "before" event.

        A "before" event (e.g., :attr:`~Event.BEFORE_LOSS`) has a corresponding "after" event
        (e.g., :attr:`~Event.AFTER_LOSS`).
        """
        return self in _BEFORE_EVENTS

    @property
    def is_after_event(self) -> bool:
        """Whether the event is an "after" event.

        An "after" event (e.g., :attr:`~Event.AFTER_LOSS`) has a corresponding "before" event
        (e.g., :attr:`~Event.BEFORE_LOSS`).
        """
        return self in _AFTER_EVENTS

    @property
    def canonical_name(self) -> str:
        """The name of the event, without before/after markers.

        Events that have a corresponding "before" or "after" event share the same canonical name.

        Example:
            >>> Event.EPOCH_START.canonical_name
            'epoch'
            >>> Event.EPOCH_END.canonical_name
            'epoch'

        Returns:
            str: The canonical name of the event.
        """
        # Strip the positional markers; e.g. 'before_loss' and 'after_loss'
        # both reduce to 'loss', 'epoch_start' and 'epoch_end' to 'epoch'.
        name: str = self.value
        name = name.replace('before_', '')
        name = name.replace('after_', '')
        name = name.replace('_start', '')
        name = name.replace('_end', '')
        return name

    @property
    def is_predict(self) -> bool:
        """Whether the event is during the predict loop."""
        return self.value.startswith('predict')

    @property
    def is_eval(self) -> bool:
        """Whether the event is during the eval loop."""
        return self.value.startswith('eval')
# Events that open a paired span; each has a matching closing event in
# ``_AFTER_EVENTS`` (e.g. BEFORE_FORWARD/AFTER_FORWARD). Queried by
# ``Event.is_before_event``.
_BEFORE_EVENTS = (
    Event.BEFORE_LOAD,
    Event.FIT_START,
    Event.ITERATION_START,
    Event.EPOCH_START,
    Event.BEFORE_DATALOADER,
    Event.BATCH_START,
    Event.BEFORE_TRAIN_BATCH,
    Event.BEFORE_FORWARD,
    Event.BEFORE_LOSS,
    Event.BEFORE_BACKWARD,
    Event.EVAL_BEFORE_ALL,
    Event.EVAL_START,
    Event.EVAL_BATCH_START,
    Event.EVAL_BEFORE_FORWARD,
    Event.PREDICT_START,
    Event.PREDICT_BATCH_START,
    Event.PREDICT_BEFORE_FORWARD,
    Event.EVAL_STANDALONE_START,
)
# Events that close a paired span; each has a matching opening event in
# ``_BEFORE_EVENTS``. Queried by ``Event.is_after_event``.
_AFTER_EVENTS = (
    Event.AFTER_LOAD,
    Event.ITERATION_END,
    Event.EPOCH_END,
    Event.BATCH_END,
    Event.AFTER_DATALOADER,
    Event.AFTER_TRAIN_BATCH,
    Event.AFTER_FORWARD,
    Event.AFTER_LOSS,
    Event.AFTER_BACKWARD,
    Event.EVAL_AFTER_ALL,
    Event.EVAL_END,
    Event.EVAL_BATCH_END,
    Event.EVAL_AFTER_FORWARD,
    Event.FIT_END,
    Event.PREDICT_END,
    Event.PREDICT_BATCH_END,
    Event.PREDICT_AFTER_FORWARD,
    Event.EVAL_STANDALONE_END,
)