Callback

class composer.Callback(*args, **kwargs)

Base class for callbacks.

Callbacks provide hooks that run at each training loop Event. Like an Algorithm, a callback runs on specific events. Unlike an Algorithm, it should not modify the training of the model. By convention, callbacks should not modify the State. They are typically used for non-essential recording functions such as logging or timing.

Callbacks can be implemented in two ways:

  1. Override the individual methods named for each Event.

    For example,

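    >>> from composer import Callback, Event, Logger, State, Trainer
    >>> class MyCallback(Callback):
    ...     def epoch_start(self, state: State, logger: Logger):
    ...         print(f'Epoch: {int(state.timestamp.epoch)}')
    >>> # construct a trainer with your callback; model, train_dataloader,
    >>> # eval_dataloader, and optimizer are assumed to be defined already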
    >>> trainer = Trainer(
    ...     model=model,
    ...     train_dataloader=train_dataloader,
    ...     eval_dataloader=eval_dataloader,
    ...     optimizers=optimizer,
    ...     max_duration="1ep",
    ...     callbacks=[MyCallback()],
    ... )
    >>> # the trainer runs MyCallback whenever the EPOCH_START
    >>> # event is triggered, for example:
    >>> _ = trainer.engine.run_event(Event.EPOCH_START)
    Epoch: 0
    
  2. Override run_event() if you want a single method to handle all events. If you override this method, the individual methods named for each event (such as epoch_start()) are no longer invoked automatically. For example, if you override run_event(), then epoch_start() will not be called on the Event.EPOCH_START event, batch_start() will not be called on the Event.BATCH_START event, and so on. You can, however, invoke epoch_start(), batch_start(), etc. yourself from your overriding implementation of run_event().

    For example,

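    >>> from composer import Callback, Event, Logger, State, Trainer
    >>> class MyCallback(Callback):
    ...     def run_event(self, event: Event, state: State, logger: Logger):
    ...         if event == Event.EPOCH_START:
    ...             print(f'Epoch: {int(state.timestamp.epoch)}')
    >>> # construct a trainer with your callback; model, train_dataloader,
    >>> # eval_dataloader, and optimizer are assumed to be defined already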
    >>> trainer = Trainer(
    ...     model=model,
    ...     train_dataloader=train_dataloader,
    ...     eval_dataloader=eval_dataloader,
    ...     optimizers=optimizer,
    ...     max_duration="1ep",
    ...     callbacks=[MyCallback()],
    ... )
    >>> # the trainer runs MyCallback whenever the EPOCH_START
    >>> # event is triggered, for example:
    >>> _ = trainer.engine.run_event(Event.EPOCH_START)
    Epoch: 0
    
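The difference between the two approaches comes down to how run_event() dispatches events. The sketch below uses hypothetical stand-ins for Event and Callback so it runs without Composer installed, and it assumes, as a simplification, that the base run_event() calls the method whose name matches the event's value. PerEventCallback follows approach 1. SingleHandlerCallback follows approach 2 and forwards only Event.EPOCH_START to epoch_start(), so its batch_start() is never reached.

```python
from enum import Enum


class Event(Enum):
    # Hypothetical stand-in for composer.Event; each value matches the
    # name of the corresponding per-event hook method.
    EPOCH_START = "epoch_start"
    BATCH_START = "batch_start"


class Callback:
    # Simplified stand-in for composer.Callback (an assumption, not the
    # library source): the default run_event() dispatches to the method
    # whose name matches the event's value.
    def run_event(self, event, state, logger):
        return getattr(self, event.value)(state, logger)

    def epoch_start(self, state, logger):
        pass

    def batch_start(self, state, logger):
        pass


class PerEventCallback(Callback):
    # Approach 1: override individual hooks; the default run_event() finds them.
    def __init__(self):
        self.calls = []

    def epoch_start(self, state, logger):
        self.calls.append("epoch_start")

    def batch_start(self, state, logger):
        self.calls.append("batch_start")


class SingleHandlerCallback(Callback):
    # Approach 2: override run_event(). batch_start() is never called
    # automatically; epoch_start() runs only because run_event() calls it.
    def __init__(self):
        self.calls = []

    def run_event(self, event, state, logger):
        self.calls.append(f"run_event:{event.value}")
        if event == Event.EPOCH_START:
            self.epoch_start(state, logger)

    def epoch_start(self, state, logger):
        self.calls.append("epoch_start")

    def batch_start(self, state, logger):
        self.calls.append("batch_start")


per_event = PerEventCallback()
single = SingleHandlerCallback()
for event in (Event.EPOCH_START, Event.BATCH_START):
    # state and logger do not affect dispatch here, so None is passed.
    per_event.run_event(event, None, None)
    single.run_event(event, None, None)

print(per_event.calls)  # ['epoch_start', 'batch_start']
print(single.calls)     # ['run_event:epoch_start', 'epoch_start', 'run_event:batch_start']
```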
after_backward(state, logger)

Called on the Event.AFTER_BACKWARD event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

after_dataloader(state, logger)

Called on the Event.AFTER_DATALOADER event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

after_forward(state, logger)

Called on the Event.AFTER_FORWARD event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

after_load(state, logger)

Called on the Event.AFTER_LOAD event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

after_loss(state, logger)

Called on the Event.AFTER_LOSS event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

after_train_batch(state, logger)

Called on the Event.AFTER_TRAIN_BATCH event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

batch_checkpoint(state, logger)

Called on the Event.BATCH_CHECKPOINT event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

batch_end(state, logger)

Called on the Event.BATCH_END event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

batch_start(state, logger)

Called on the Event.BATCH_START event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.
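Since callbacks are typically used for timing, batch_start() and batch_end() make a natural pair. The following is a minimal sketch of a hypothetical BatchTimer. It is a plain class so that it runs standalone, whereas a real callback would subclass composer.Callback. The loop at the bottom stands in for the trainer calling the two hooks.

```python
import time


class BatchTimer:
    # Hypothetical timing callback. In real use it would subclass
    # composer.Callback; here it is a plain class so the sketch runs
    # without Composer installed. state and logger match the hook
    # signature but are not used.
    def __init__(self):
        self._start = None
        self.durations = []  # wall-clock seconds per batch

    def batch_start(self, state, logger):
        self._start = time.perf_counter()

    def batch_end(self, state, logger):
        if self._start is None:
            return  # batch_end without a matching batch_start
        self.durations.append(time.perf_counter() - self._start)
        self._start = None


timer = BatchTimer()
for _ in range(3):
    timer.batch_start(None, None)
    time.sleep(0.01)  # stand-in for the work of one training batch
    timer.batch_end(None, None)

print(len(timer.durations))  # 3
```

In a real callback the duration would usually be reported through the Logger instead of being stored on the object; which Logger method to call depends on your Composer version. The timer keeps its own attribute rather than writing to the State, in line with the convention that callbacks should not modify the State.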

before_backward(state, logger)

Called on the Event.BEFORE_BACKWARD event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

before_dataloader(state, logger)

Called on the Event.BEFORE_DATALOADER event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

before_forward(state, logger)

Called on the Event.BEFORE_FORWARD event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

before_loss(state, logger)

Called on the Event.BEFORE_LOSS event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

before_train_batch(state, logger)

Called on the Event.BEFORE_TRAIN_BATCH event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

close(state, logger)

Called whenever the trainer finishes training, even if an exception was raised.

Use it for cleanup tasks such as flushing I/O streams and closing any files opened during the Event.INIT event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.
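The sketch below illustrates the intended init()/close() pairing with a hypothetical FileRecorder and a stand-in fake_fit() function that plays the trainer's role. Neither is part of Composer: fake_fit() only mimics the documented guarantee that close() runs even when training raises.

```python
import os
import tempfile


class FileRecorder:
    # Hypothetical callback that opens a file in init() and releases it in
    # close(). It is a plain class (not a composer.Callback subclass) so the
    # sketch runs standalone; state and logger are unused.
    def __init__(self, path):
        self.path = path
        self.file = None

    def init(self, state, logger):
        self.file = open(self.path, "w")

    def epoch_end(self, state, logger):
        self.file.write("epoch finished\n")

    def close(self, state, logger):
        # Cleanup may run after a failure, possibly before init() ever ran,
        # so guard against a missing or already-closed file.
        if self.file is not None and not self.file.closed:
            self.file.flush()
            self.file.close()


def fake_fit(callback, fail=False):
    # Stand-in for the trainer: close() sits in a finally block, so it is
    # invoked whether or not training raises.
    try:
        callback.init(None, None)
        callback.epoch_end(None, None)
        if fail:
            raise RuntimeError("simulated training failure")
    finally:
        callback.close(None, None)


path = os.path.join(tempfile.mkdtemp(), "events.txt")
recorder = FileRecorder(path)
try:
    fake_fit(recorder, fail=True)
except RuntimeError:
    pass  # the failure still propagates, but the file has been closed

print(recorder.file.closed)  # True
with open(path) as f:
    print(f.read(), end="")  # epoch finished
```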

epoch_checkpoint(state, logger)

Called on the Event.EPOCH_CHECKPOINT event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

epoch_end(state, logger)

Called on the Event.EPOCH_END event.

Note

The epoch counter Timestamp.epoch, accessed as State.timestamp.epoch, is incremented immediately before Event.EPOCH_END, so during epoch_end() it already counts the epoch that just finished.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.
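One consequence of this ordering is that epoch_end() and the following epoch_start() observe the same counter value. The simulation below is a hypothetical stand-in (a SimpleNamespace in place of State and a fake_epochs() loop in place of the trainer) that reproduces the ordering described in the note; it is not Composer's training loop.

```python
from types import SimpleNamespace


class EpochRecorder:
    # Records the epoch counter observed by each hook. It only reads the
    # state, following the convention that callbacks do not modify it.
    def __init__(self):
        self.seen = []

    def epoch_start(self, state, logger):
        self.seen.append(("epoch_start", int(state.timestamp.epoch)))

    def epoch_end(self, state, logger):
        self.seen.append(("epoch_end", int(state.timestamp.epoch)))


def fake_epochs(callback, num_epochs):
    # Stand-in trainer loop: the epoch counter is incremented immediately
    # before the EPOCH_END hook fires, as described in the note.
    state = SimpleNamespace(timestamp=SimpleNamespace(epoch=0))
    for _ in range(num_epochs):
        callback.epoch_start(state, None)
        # ... training batches would run here ...
        state.timestamp.epoch += 1
        callback.epoch_end(state, None)


recorder = EpochRecorder()
fake_epochs(recorder, 2)
print(recorder.seen)
# [('epoch_start', 0), ('epoch_end', 1), ('epoch_start', 1), ('epoch_end', 2)]
```

So inside epoch_end(), the zero-based index of the epoch that just ended is int(state.timestamp.epoch) - 1.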

epoch_start(state, logger)

Called on the Event.EPOCH_START event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

eval_after_forward(state, logger)

Called on the Event.EVAL_AFTER_FORWARD event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

eval_batch_end(state, logger)

Called on the Event.EVAL_BATCH_END event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

eval_batch_start(state, logger)

Called on the Event.EVAL_BATCH_START event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

eval_before_forward(state, logger)

Called on the Event.EVAL_BEFORE_FORWARD event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

eval_end(state, logger)

Called on the Event.EVAL_END event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

eval_start(state, logger)

Called on the Event.EVAL_START event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

fit_end(state, logger)

Called on the Event.FIT_END event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

fit_start(state, logger)

Called on the Event.FIT_START event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

init(state, logger)

Called on the Event.INIT event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

post_close()

Called after close() has been invoked for each callback.

Very few callbacks should need to implement post_close(). It can be used to back up any data that other callbacks may have written during close().
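The two-phase shutdown can be sketched as follows. fake_shutdown() and OrderRecorder are hypothetical, and the sketch assumes callbacks are visited in list order within each phase, which the text above does not specify. The point is that the "uploader" callback's post_close() runs only after every callback, including the "writer", has finished close().

```python
class OrderRecorder:
    # Hypothetical callback that appends to a shared log so the call order
    # across callbacks can be inspected. state and logger are unused.
    def __init__(self, name, log):
        self.name = name
        self.log = log

    def close(self, state, logger):
        self.log.append(f"{self.name}.close")

    def post_close(self):
        # Note: post_close() takes no state or logger arguments.
        self.log.append(f"{self.name}.post_close")


def fake_shutdown(callbacks):
    # Stand-in for the trainer's shutdown: every callback's close() runs
    # first, and only then is post_close() called on each callback.
    for cb in callbacks:
        cb.close(None, None)
    for cb in callbacks:
        cb.post_close()


log = []
fake_shutdown([OrderRecorder("writer", log), OrderRecorder("uploader", log)])
print(log)
# ['writer.close', 'uploader.close', 'writer.post_close', 'uploader.post_close']
```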

predict_after_forward(state, logger)

Called on the Event.PREDICT_AFTER_FORWARD event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

predict_batch_end(state, logger)

Called on the Event.PREDICT_BATCH_END event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

predict_batch_start(state, logger)

Called on the Event.PREDICT_BATCH_START event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

predict_before_forward(state, logger)

Called on the Event.PREDICT_BEFORE_FORWARD event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

predict_end(state, logger)

Called on the Event.PREDICT_END event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

predict_start(state, logger)

Called on the Event.PREDICT_START event.

Parameters
  • state (State) – The training state.

  • logger (Logger) – The logger.

run_event(event, state, logger)

Called by the engine on each event.

Parameters
  • event (Event) – The event.

  • state (State) – The training state.

  • logger (Logger) – The logger.