# Copyright 2021 MosaicML. All Rights Reserved.
"""Events represent specific points in the training loop where an :class:`~.core.Algorithm` and :class:`~.core.Callback`
can run."""
from composer.utils.string_enum import StringEnum
__all__ = ["Event"]
class Event(StringEnum):
    """Enum to represent events in the training loop.

    The following pseudocode shows where each event fires in the training loop:

    .. code-block:: python

        # <INIT>
        # <FIT_START>
        for epoch in range(NUM_EPOCHS):
            # <EPOCH_START>
            for inputs, targets in dataloader:
                # <AFTER_DATALOADER>
                # <BATCH_START>
                # <BEFORE_TRAIN_BATCH>
                # <BEFORE_FORWARD>
                outputs = model.forward(inputs)
                # <AFTER_FORWARD>
                # <BEFORE_LOSS>
                loss = model.loss(outputs, targets)
                # <AFTER_LOSS>
                # <BEFORE_BACKWARD>
                loss.backward()
                # <AFTER_BACKWARD>
                optimizer.step()
                # <AFTER_TRAIN_BATCH>
                # <BATCH_END>
                if should_eval(batch=True):
                    # <EVAL_START>
                    # <EVAL_BATCH_START>
                    # <EVAL_BEFORE_FORWARD>
                    # <EVAL_AFTER_FORWARD>
                    # <EVAL_BATCH_END>
                    # <EVAL_END>
                # <BATCH_CHECKPOINT>
            # <EPOCH_END>
            if should_eval(batch=False):
                # <EVAL_START>
                # <EVAL_BATCH_START>
                # <EVAL_BEFORE_FORWARD>
                # <EVAL_AFTER_FORWARD>
                # <EVAL_BATCH_END>
                # <EVAL_END>
            # <EPOCH_CHECKPOINT>

    Attributes:
        INIT: Invoked in the constructor of :class:`~.trainer.Trainer`. Model surgery (see
            :mod:`~composer.utils.module_surgery`) typically occurs here.
        FIT_START: Invoked at the beginning of each call to :meth:`.Trainer.fit`. Dataset transformations typically
            occur here.
        EPOCH_START: Start of an epoch.
        BATCH_START: Start of a batch.
        AFTER_DATALOADER: Immediately after the dataloader is called. Typically used for on-GPU dataloader transforms.
        BEFORE_TRAIN_BATCH: Before the forward-loss-backward computation for a training batch. When using gradient
            accumulation, this is still called only once.
        BEFORE_FORWARD: Before the call to ``model.forward()``.
        AFTER_FORWARD: After the call to ``model.forward()``.
        BEFORE_LOSS: Before the call to ``model.loss()``.
        AFTER_LOSS: After the call to ``model.loss()``.
        BEFORE_BACKWARD: Before the call to ``loss.backward()``.
        AFTER_BACKWARD: After the call to ``loss.backward()``.
        AFTER_TRAIN_BATCH: After the forward-loss-backward computation for a training batch. When using gradient
            accumulation, this event still fires only once.
        BATCH_END: End of a batch, which occurs after the optimizer step and any gradient scaling.
        BATCH_CHECKPOINT: After :attr:`.Event.BATCH_END` and any batch-wise evaluation. Saving checkpoints at this
            event allows the checkpoint saver to use the results from any batch-wise evaluation to determine whether
            a checkpoint should be saved.
        EPOCH_END: End of an epoch.
        EPOCH_CHECKPOINT: After :attr:`.Event.EPOCH_END` and any epoch-wise evaluation. Saving checkpoints at this
            event allows the checkpoint saver to use the results from any epoch-wise evaluation to determine whether
            a checkpoint should be saved.
        EVAL_START: Start of evaluation through the validation dataset.
        EVAL_BATCH_START: Before the call to ``model.validate(batch)``.
        EVAL_BEFORE_FORWARD: Before the call to ``model.validate(batch)``.
        EVAL_AFTER_FORWARD: After the call to ``model.validate(batch)``.
        EVAL_BATCH_END: After the call to ``model.validate(batch)``.
        EVAL_END: End of evaluation through the validation dataset.
    """
    INIT = "init"
    FIT_START = "fit_start"
    EPOCH_START = "epoch_start"
    BATCH_START = "batch_start"
    AFTER_DATALOADER = "after_dataloader"
    BEFORE_TRAIN_BATCH = "before_train_batch"
    BEFORE_FORWARD = "before_forward"
    AFTER_FORWARD = "after_forward"
    BEFORE_LOSS = "before_loss"
    AFTER_LOSS = "after_loss"
    BEFORE_BACKWARD = "before_backward"
    AFTER_BACKWARD = "after_backward"
    AFTER_TRAIN_BATCH = "after_train_batch"
    BATCH_END = "batch_end"
    BATCH_CHECKPOINT = "batch_checkpoint"
    EPOCH_END = "epoch_end"
    EPOCH_CHECKPOINT = "epoch_checkpoint"
    EVAL_START = "eval_start"
    EVAL_BATCH_START = "eval_batch_start"
    EVAL_BEFORE_FORWARD = "eval_before_forward"
    EVAL_AFTER_FORWARD = "eval_after_forward"
    EVAL_BATCH_END = "eval_batch_end"
    EVAL_END = "eval_end"

    @property
    def is_before_event(self) -> bool:
        """Whether the event is a 'before_*' event (e.g., :attr:`~Event.BEFORE_LOSS`) and has a corresponding
        'after_*' event (e.g., :attr:`~Event.AFTER_LOSS`)."""
        # Membership is defined by the module-level pairing tables below the class.
        return self in _BEFORE_EVENTS

    @property
    def is_after_event(self) -> bool:
        """Whether the event is an 'after_*' event (e.g., :attr:`~Event.AFTER_LOSS`) and has a corresponding
        'before_*' event (e.g., :attr:`~Event.BEFORE_LOSS`)."""
        return self in _AFTER_EVENTS

    @property
    def canonical_name(self) -> str:
        """The name of the event, without before/after markers.

        Events that have a corresponding "before" or "after" event share the same canonical name.

        Example:
            >>> Event.EPOCH_START.canonical_name
            'epoch'
            >>> Event.EPOCH_END.canonical_name
            'epoch'

        Returns:
            str: The canonical name of the event.
        """
        # Strip the directional prefixes/suffixes so paired events collapse to one name,
        # e.g. "before_loss"/"after_loss" -> "loss", "epoch_start"/"epoch_end" -> "epoch".
        name: str = self.value
        name = name.replace("before_", "")
        name = name.replace("after_", "")
        name = name.replace("_start", "")
        name = name.replace("_end", "")
        return name
# Pairing tables consulted by ``Event.is_before_event`` / ``Event.is_after_event``.
# The two tuples list, in the same order, each 'before'-style event alongside its
# 'after'-style counterpart (e.g. EPOCH_START <-> EPOCH_END).
_BEFORE_EVENTS = (
    Event.EPOCH_START,
    Event.BATCH_START,
    Event.BEFORE_TRAIN_BATCH,
    Event.BEFORE_FORWARD,
    Event.BEFORE_LOSS,
    Event.BEFORE_BACKWARD,
    Event.EVAL_START,
    Event.EVAL_BATCH_START,
    Event.EVAL_BEFORE_FORWARD,
)
_AFTER_EVENTS = (
    Event.EPOCH_END,
    Event.BATCH_END,
    Event.AFTER_TRAIN_BATCH,
    Event.AFTER_FORWARD,
    Event.AFTER_LOSS,
    Event.AFTER_BACKWARD,
    Event.EVAL_END,
    Event.EVAL_BATCH_END,
    Event.EVAL_AFTER_FORWARD,
)