# Source code for composer.core.callback

# Copyright 2021 MosaicML. All Rights Reserved.

"""Base module for callbacks."""
from __future__ import annotations

import abc
from typing import TYPE_CHECKING

from composer.core.serializable import Serializable

if TYPE_CHECKING:
    from composer import Event, State
    from composer.loggers import Logger

__all__ = ["Callback"]


class Callback(Serializable, abc.ABC):
    """Base class for callbacks.

    Callbacks provide hooks that can run at each training loop :class:`~.event.Event`. A callback is similar to
    an :class:`~.algorithm.Algorithm` in that they are run on specific events. Callbacks differ from
    :class:`~.algorithm.Algorithm` in that they do not modify the training of the model. By convention, callbacks
    should not modify the :class:`~.state.State`. They are typically used for non-essential recording functions
    such as logging or timing.

    Callbacks can be implemented in two ways:

    #.  Override the individual methods named for each :class:`~.event.Event`.

        For example,

        .. doctest::

            >>> class MyCallback(Callback):
            ...     def epoch_start(self, state: State, logger: Logger):
            ...         print(f'Epoch: {int(state.timer.epoch)}')
            >>> # construct trainer object with your callback
            >>> trainer = Trainer(
            ...     model=model,
            ...     train_dataloader=train_dataloader,
            ...     eval_dataloader=eval_dataloader,
            ...     optimizers=optimizer,
            ...     max_duration="1ep",
            ...     callbacks=[MyCallback()],
            ... )
            >>> # trainer will run MyCallback whenever the EPOCH_START
            >>> # is triggered, like this:
            >>> _ = trainer.engine.run_event(Event.EPOCH_START)
            Epoch: 0

        .. testcleanup::

            trainer.engine.close()

    #.  Override :meth:`run_event` if you want a single method to handle all events. If this method is overridden,
        then the individual methods corresponding to each event name (such as :meth:`epoch_start`) will no longer
        be automatically invoked. For example, if you override :meth:`run_event` then :meth:`epoch_start` will not
        be called on the :attr:`~.Event.EPOCH_START` event, :meth:`batch_start` will not be called on the
        :attr:`~.Event.BATCH_START` etc. However, you can invoke :meth:`epoch_start`, :meth:`batch_start` etc. in
        your overriding implementation of :meth:`run_event`.

        For example,

        .. doctest::

            >>> class MyCallback(Callback):
            ...     def run_event(self, event: Event, state: State, logger: Logger):
            ...         if event == Event.EPOCH_START:
            ...             print(f'Epoch: {int(state.timer.epoch)}')
            >>> # construct trainer object with your callback
            >>> trainer = Trainer(
            ...     model=model,
            ...     train_dataloader=train_dataloader,
            ...     eval_dataloader=eval_dataloader,
            ...     optimizers=optimizer,
            ...     max_duration="1ep",
            ...     callbacks=[MyCallback()],
            ... )
            >>> # trainer will run MyCallback whenever the EPOCH_START
            >>> # is triggered, like this:
            >>> _ = trainer.engine.run_event(Event.EPOCH_START)
            Epoch: 0

        .. testcleanup::

            trainer.engine.close()
    """

    def run_event(self, event: Event, state: State, logger: Logger) -> None:
        """This method is called by the engine on each event.

        The default implementation looks up the method on ``self`` whose name equals ``event.value`` and
        calls it with ``state`` and ``logger``, returning whatever that method returns.

        Args:
            event (Event): The event.
            state (State): The state.
            logger (Logger): The logger.
        """
        # Dispatch by name: the hook to run is the attribute named by the event's value.
        event_cb = getattr(self, event.value)
        return event_cb(state, logger)

    def init(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.INIT` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def fit_start(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.FIT_START` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def epoch_start(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.EPOCH_START` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def batch_start(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.BATCH_START` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def after_dataloader(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.AFTER_DATALOADER` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def before_train_batch(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.BEFORE_TRAIN_BATCH` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def before_forward(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.BEFORE_FORWARD` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def after_forward(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.AFTER_FORWARD` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def before_loss(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.BEFORE_LOSS` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def after_loss(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.AFTER_LOSS` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def before_backward(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.BEFORE_BACKWARD` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def after_backward(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.AFTER_BACKWARD` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def after_train_batch(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.AFTER_TRAIN_BATCH` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def batch_end(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.BATCH_END` event.

        .. note::

            The following :class:`~.time.Timer` member variables are
            incremented immediately before the :attr:`~.Event.BATCH_END` event.

            +-----------------------------------+
            | :attr:`.Timer.batch`              |
            +-----------------------------------+
            | :attr:`.Timer.batch_in_epoch`     |
            +-----------------------------------+
            | :attr:`.Timer.sample`             |
            +-----------------------------------+
            | :attr:`.Timer.sample_in_epoch`    |
            +-----------------------------------+
            | :attr:`.Timer.token`              |
            +-----------------------------------+
            | :attr:`.Timer.token_in_epoch`     |
            +-----------------------------------+

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def batch_checkpoint(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.BATCH_CHECKPOINT` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def epoch_end(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.EPOCH_END` event.

        .. note::

            :class:`~.time.Timer` member variable :attr:`.Timer.epoch`
            is incremented immediately before :attr:`~.Event.EPOCH_END`.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def epoch_checkpoint(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.EPOCH_CHECKPOINT` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def eval_start(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.EVAL_START` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def eval_batch_start(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.EVAL_BATCH_START` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def eval_before_forward(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.EVAL_BEFORE_FORWARD` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def eval_after_forward(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.EVAL_AFTER_FORWARD` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def eval_batch_end(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.EVAL_BATCH_END` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def eval_end(self, state: State, logger: Logger) -> None:
        """Called on the :attr:`~.Event.EVAL_END` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def close(self, state: State, logger: Logger) -> None:
        """Called whenever the trainer finishes training, even when there is an exception.

        It should be used for clean up tasks such as flushing I/O streams and/or
        closing any files that may have been opened during the :attr:`~.Event.INIT` event.

        Args:
            state (State): The global state.
            logger (Logger): The logger.
        """
        del state, logger  # unused

    def post_close(self) -> None:
        """This hook is called after :meth:`close` has been invoked for each callback.

        Very few callbacks should need to implement :meth:`post_close`. This callback can be used to back up any
        data that may have been written by other callbacks during :meth:`close`.
        """