# Copyright 2021 MosaicML. All Rights Reserved.
"""Engine is a coordinator for running algorithms and resolving ordering conflicts among them for composition.
.. currentmodule:: composer
The order in which algorithms are run matters significantly during composition. For example,
the :class:`~.SelectiveBackprop` algorithm runs on the :attr:`~.Event.AFTER_DATALOADER` event and must run before any data
augmentations. :class:`~.engine.Engine` runs re-ordering passes to resolve such ordering issues or conflicts.
.. note::
* An instance of :class:`~.engine.Engine` is automatically constructed by the :class:`~.trainer.Trainer`
constructor. Users need not instantiate the :class:`~.engine.Engine` class themselves.
* The design of :class:`~.engine.Engine` is subject to change in future releases to accommodate more complexity as
we investigate composition of algorithms.
Currently, the following passes are registered:
* **LIFO order for events**
For the events that follow the ``before_*`` (e.g., :attr:`~.Event.BEFORE_LOSS`) and ``after_*`` (e.g.,
:attr:`~.Event.AFTER_LOSS`) pattern, the ordering of algorithms is reversed for the ``after_*`` events. For example,
four given algorithms ``A``, ``B``, ``C`` and ``D`` will run in ``ABCD`` ordering on the ``before_*`` event while
``DCBA`` ordering on the ``after_*`` event.
This allows algorithms to "clean up" their changes. For example, :class:`~.LabelSmoothing` will smooth the labels
on the :attr:`~.Event.BEFORE_LOSS` event and then restore the original unsmoothed labels on the
:attr:`~.Event.AFTER_LOSS` event.
* **Run Selective Backprop first**
:class:`~.SelectiveBackprop` runs after the dataloader returns the batch, and executes an extra forward pass to rank
and prune the examples in the batch by loss. To ensure a clean estimate of loss, :class:`~.SelectiveBackprop` should
run before any other data augmentations (e.g., :class:`~.MixUp`) on the :attr:`~.Event.AFTER_DATALOADER` event.
The engine likewise moves :class:`~.StochasticDepth` ahead of the other algorithms.
Trace
~~~~~
Traces record whether an algorithm ran at a particular step and event combination and also the order of such executions.
These are logged with the key ``<algorithm_name>/<event>``.
For example, the algorithm :class:`~.LayerFreezing`, which runs at the end of every epoch on :attr:`~.Event.EPOCH_END`,
will emit a series of traces:
.. code-block::
[STEP=0][layer_freezing/INIT=0]
[STEP=1][layer_freezing/EPOCH_START=0]
[STEP=1][layer_freezing/BATCH_START=0]
...
[STEP=2][layer_freezing/BATCH_START=0]
...
[STEP=3][layer_freezing/BATCH_START=0]
...
[STEP=3][layer_freezing/EPOCH_END=1] # <-- layer freezing ran on step 3 here!
"""
from __future__ import annotations
import contextlib
import logging
from collections import OrderedDict
from dataclasses import dataclass
from typing import ContextManager, Dict, Optional, Sequence, Union, cast
from composer.core.algorithm import Algorithm
from composer.core.callback import Callback
from composer.core.event import Event
from composer.core.state import State
from composer.loggers import Logger, LogLevel
from composer.profiler import ProfilerAction
log = logging.getLogger(__name__)

__all__ = ["Trace", "Engine", "Traces"]

#: Type alias for the traces produced by the engine.
#: Keys have the format ``<algorithm_name>/<event>`` (e.g., ``Blurpool/INIT``) and values are :class:`Trace`
#: instances. The engine itself builds these as an ``OrderedDict`` so that insertion order is preserved.
Traces = Dict[str, "Trace"]

#: Events whose profiler duration markers are also recorded while the profiler is in the ``SKIP`` state
#: (in addition to ``ACTIVE`` and ``WARMUP``).
_ALWAYS_RECORD_EVENTS = [
    Event.INIT,
    Event.FIT_START,
    Event.EPOCH_START,
    Event.EPOCH_END,
]
@dataclass
class Trace():
    """Record of an algorithm's execution.

    Attributes:
        exit_code (int or None): Optional return value from an algorithm. Default: None.
        order (int or None): Order in which the algorithm was executed
            in the list of algorithms. None means algorithm was not run.
        run (bool): Whether the algorithm was run. Default: False
    """
    # Value returned by ``Algorithm.apply``; None if the algorithm did not run or returned nothing.
    exit_code: Optional[int] = None
    # Zero-based position of the algorithm in the (re-ordered) execution list for the event.
    order: Optional[int] = None
    # True once the algorithm has actually been applied for the event.
    run: bool = False
def _setup_trace(algorithms: Sequence[Algorithm], event: Event) -> Traces:
    """Build the initial, not-yet-run traces for a set of algorithms at ``event``.

    Args:
        algorithms (Sequence[Algorithm]): The algorithms that will run for this event.
        event (Event): The event being run.

    Returns:
        Traces: An ``OrderedDict`` with one entry per algorithm, in the given order. Keys have the format
        ``<algorithm_name>/<event>`` (e.g., ``Blurpool/INIT``), produced by formatting the algorithm and the event
        into an f-string; every value is a default :class:`Trace` (i.e. ``run=False``).
    """
    traces: Traces = OrderedDict()
    for algorithm in algorithms:
        traces[f'{algorithm}/{event}'] = Trace()
    return traces
class Engine():
    """Coordinator for running algorithms and resolving ordering conflicts among them for composition.

    Args:
        state (State): The initial :class:`~.state.State` of the trainer. ``state`` will be modified in-place.
        logger (Logger): A :class:`~.logger.Logger` instance to be used for logging algorithm and callback
            specific metrics.
    """

    def __init__(self, state: State, logger: Logger):
        self.logger = logger
        self.state = state

    def run_event(
        self,
        event: Union[Event, str],
    ) -> Traces:
        """Runs the sequence of algorithms and callbacks (see :class:`~.callback.Callback`).

        Filters algorithms by calling each one's :meth:`~.Algorithm.match` method, re-orders the matching
        algorithms to resolve ordering conflicts, then runs each algorithm's :meth:`~.Algorithm.apply` method to make
        in-place changes to the ``state``.

        The default order of execution for algorithms is determined by the provided list. However, :class:`Engine` makes
        changes to this order internally to resolve ordering conflicts.

        Returns :data:`Traces` of the execution, a dictionary with keys formatted as ``<algorithm_name>/<event>`` (e.g.,
        ``Blurpool/INIT``), and values are an instance of :class:`~.engine.Trace`.

        Callbacks run after the algorithms, except on :attr:`~.Event.INIT`, where they run first. Callbacks do not
        return a trace.

        This method can be called with either the :class:`~.event.Event` enum member values or a string of the event
        name.

        Examples:
            >>> engine = Engine(state, logger)
            >>> engine.run_event(Event.BEFORE_LOSS)
            OrderedDict()
            >>> # calling with a string of the event name also works
            >>> engine.run_event('before_loss')
            OrderedDict()

        Args:
            event (Event or str): The current :class:`~.event.Event`. It can be the enum member values or a
                string with the event value.

        Returns:
            traces (Traces): Ordered dictionary of trace for each algorithm.
        """
        duration_marker = None
        event = Event(event)

        if self.state.profiler is not None:
            name = f"event/{event.canonical_name}"
            if (event.is_before_event or event.is_after_event):
                # Only events that belong to a before_*/after_* pair get a duration marker; other events
                # (e.g. init or after dataloader) are not recorded here.
                if event in _ALWAYS_RECORD_EVENTS:
                    actions = [ProfilerAction.ACTIVE, ProfilerAction.WARMUP, ProfilerAction.SKIP]
                else:
                    actions = [ProfilerAction.ACTIVE, ProfilerAction.WARMUP]
                duration_marker = self.state.profiler.marker(name, actions=actions)

        # For after_* events, finish the marker before algorithms and callbacks run.
        if event.is_after_event and duration_marker is not None:
            duration_marker.finish()

        if event == Event.INIT:
            # For the INIT event, run the callbacks first to initialize the loggers
            # For other events, run the algorithms first, so the callbacks have the state
            # after algorithms modify it
            self._run_callbacks(event)
            traces = self._run_algorithms(event)
        else:
            traces = self._run_algorithms(event)
            self._run_callbacks(event)

        # For before_* events, start the marker after algorithms and callbacks have run.
        if event.is_before_event and duration_marker is not None:
            duration_marker.start()

        return traces

    def _run_algorithms(
        self,
        event: Event,
    ) -> Traces:
        """Runs the algorithms that match ``event`` (in compiled order) and logs which of them ran.

        Args:
            event (Event): The current event.

        Returns:
            Traces: One :class:`Trace` per matching algorithm, keyed by ``<algorithm_name>/<event>``.
        """
        algorithms_to_run = [algo for algo in self.state.algorithms if algo.match(event, self.state)]

        # future collision resolution
        algorithms_to_run = self._compile(algorithms_to_run, event)

        trace = _setup_trace(algorithms_to_run, event)
        for order, algorithm in enumerate(algorithms_to_run):
            marker = None
            if self.state.profiler is not None:
                marker = self.state.profiler.marker(f"algorithm/{algorithm.__class__.__name__}/event/{event.value}",
                                                    categories=[
                                                        event.value,
                                                        algorithm.__class__.__name__,
                                                    ])
            ctx = cast(ContextManager, contextlib.nullcontext()) if marker is None else marker
            with ctx:
                exit_code = algorithm.apply(event, self.state, self.logger)

            trace_key = f'{algorithm}/{event}'
            trace[trace_key] = Trace(exit_code=exit_code, order=order, run=True)

        if self.logger is not None:
            # ``elif`` is required here: with a plain ``if``, INIT/FIT_START would fall through to the ``else``
            # branch below and be logged at BATCH level instead of FIT level.
            if event in (Event.INIT, Event.FIT_START):
                log_level = LogLevel.FIT
            elif event in (Event.EPOCH_START, Event.EPOCH_END):
                log_level = LogLevel.EPOCH
            else:
                # algs don't run on eval events, so don't have to worry about
                # batch-frequency vs epoch-frequency evaluators
                log_level = LogLevel.BATCH
            if trace:
                self.logger.data(log_level=log_level, data={key: 1 if tr.run else 0 for key, tr in trace.items()})

        return trace

    def _compile(
        self,
        algorithms_to_run: Sequence[Algorithm],
        event: Event,
    ) -> Sequence[Algorithm]:
        """Runs compilation passes that modify the order and content of a list of algorithms.

        Currently, this applies two passes:

        1. :class:`~.SelectiveBackprop` and :class:`~.StochasticDepth` are moved to the front of the list. All other
           algorithms keep their relative order.
        2. For events where :attr:`Event.is_after_event` is true, the list is reversed. This gives LIFO ordering
           across a before_/after_ pair: algorithms run in order ABCD during before_loss and DCBA during after_loss,
           so algorithms can 'undo' their effects upon the exit of an event.

        Intent of this method is to eventually store and handle other algorithms collisions and ordering
        requirements.

        Args:
            algorithms_to_run (Sequence[Algorithm]): Sequence of algorithms
            event (Event): The current event

        Returns:
            algorithms_to_run (Sequence[Algorithm]): Modified sequence of algorithms
        """
        from composer.algorithms import SelectiveBackprop, StochasticDepth

        # Move SelectiveBackprop and StochasticDepth to the beginning. ``sorted`` is stable and False < True, so
        # the moved algorithms keep their relative order, as do all remaining algorithms.
        algorithms = sorted(algorithms_to_run,
                            key=lambda x: not isinstance(x, (SelectiveBackprop, StochasticDepth)))

        if event.is_after_event:
            # Establish a LIFO ordering of algorithms between before_ and after_ an event:
            #   before_loss: A, B, C, D
            #   after_loss:  D, C, B, A
            algorithms = list(reversed(algorithms))

        return algorithms

    def _run_callbacks(
        self,
        event: Union[Event, str],
    ):
        """Runs a sequence of callbacks by calling the function for an event.

        Args:
            event (Event or str): The current :class:`~.event.Event`, either the enum member or its string value.

        Returns:
            None
        """
        event = Event(event)

        for cb in self.state.callbacks:
            marker = None
            if self.state.profiler is not None:
                marker = self.state.profiler.marker(f"callback/{cb.__class__.__name__}/event/{event.value}",
                                                    categories=[
                                                        event.value,
                                                        cb.__class__.__name__,
                                                    ])
            ctx = cast(ContextManager, contextlib.nullcontext()) if marker is None else marker
            with ctx:
                cb.run_event(event, self.state, self.logger)

    def close(self) -> None:
        """Invokes :meth:`~.Callback.close` and :meth:`~.Callback.post_close` for each callback.

        :meth:`~.Callback.close` is invoked for each callback. Then :meth:`~.Callback.post_close` is invoked for every
        callback whose :meth:`~.Callback.close` did not raise an exception.

        This method does not re-raise any exceptions from :meth:`~.Callback.close` and :meth:`~.Callback.post_close`.
        Instead, each exception is reported (with its traceback) through this module's Python :mod:`logging` logger
        at ``ERROR`` level.
        """
        callback_to_has_exception: Dict[Callback, bool] = {}
        for callback in self.state.callbacks:
            try:
                callback.close(self.state, self.logger)
            except Exception as e:
                # Best-effort shutdown: report and continue so the remaining callbacks still get closed.
                log.error(
                    f"Error running {callback.__class__.__name__}.close(). Skipping {callback.__class__.__name__}.post_close().",
                    exc_info=e,
                    stack_info=True)
                callback_to_has_exception[callback] = True
            else:
                callback_to_has_exception[callback] = False

        for callback in self.state.callbacks:
            if not callback_to_has_exception[callback]:
                try:
                    callback.post_close()
                except Exception as e:
                    log.error(f"Error running {callback.__class__.__name__}.post_close().", exc_info=e, stack_info=True)