☎️ Callbacks

Callbacks provide hooks that run at each Event in the training loop. By convention, callbacks should not modify the training loop by changing the State; they should only read it and log metrics. Typical callback use cases include logging, timing, and model introspection.

Using Callbacks

Built-in callbacks can be accessed in composer.callbacks and registered with the callbacks argument to the Trainer.

from composer import Trainer
from composer.callbacks import SpeedMonitor, LRMonitor
from composer.loggers import WandBLogger

# `model` and `train_dataloader` are assumed to be defined elsewhere.
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=None,
    max_duration='1ep',
    callbacks=[SpeedMonitor(window_size=100), LRMonitor()],  # what gets logged
    loggers=[WandBLogger()],  # where it gets logged
)

This example registers callbacks that measure the model throughput and the learning rate, and logs both to Weights & Biases. Callbacks control what is logged, whereas loggers specify where the information is saved. For more information on loggers, see Logging.

Available Callbacks

Composer provides several callbacks to monitor and log various components of training.

CheckpointSaver

Callback to save checkpoints.

SpeedMonitor

Logs the training throughput and utilization.

RuntimeEstimator

Estimates total training time.

LRMonitor

Logs the learning rate.

OptimizerMonitor

Computes and logs the L2 norm of gradients as well as any optimizer-specific metrics implemented in the optimizer's report_per_parameter_metrics method.

MemoryMonitor

Logs the memory usage of the model.

MemorySnapshot

Logs the memory snapshot of the model.

OOMObserver

Generates visualizations of the state of allocated memory during an OutOfMemory exception.

NaNMonitor

Catches NaNs in the loss and raises an error if one is found.

ImageVisualizer

Logs image inputs and optionally outputs.

MLPerfCallback

Creates a compliant results file for the MLPerf Training benchmark.

ThresholdStopper

Halts training when a metric value reaches a certain threshold.

EarlyStopper

Tracks a metric and halts training if it does not improve within a given interval.

ExportForInferenceCallback

Callback to export model for inference.
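The stopping rule described for EarlyStopper can be illustrated with a small, framework-free sketch. This is not Composer's implementation: the class name PatienceTracker and its parameters patience and min_delta are hypothetical names chosen for illustration, and the sketch assumes a metric where higher is better.

```python
class PatienceTracker:
    """Hypothetical helper: signals a stop once a metric stops improving."""

    def __init__(self, patience: int = 2, min_delta: float = 0.0):
        self.patience = patience      # checks tolerated without improvement
        self.min_delta = min_delta    # minimum gain that counts as improvement
        self.best = None
        self.checks_since_best = 0

    def should_stop(self, value: float) -> bool:
        # An improvement must exceed the best value by more than min_delta.
        if self.best is None or value > self.best + self.min_delta:
            self.best = value
            self.checks_since_best = 0
        else:
            self.checks_since_best += 1
        return self.checks_since_best >= self.patience


tracker = PatienceTracker(patience=2)
accuracies = [0.61, 0.68, 0.67, 0.68, 0.70]  # made-up per-epoch values

stopped_at = None
for epoch, accuracy in enumerate(accuracies):
    if tracker.should_stop(accuracy):
        stopped_at = epoch
        break

print(f'stopped at epoch index {stopped_at}, best value {tracker.best}')
```

With patience set to 2, the loop halts at the second consecutive check that fails to beat the best value, so the final 0.70 is never seen. A larger patience trades longer training for robustness to noisy metrics.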

Custom Callbacks

Custom callbacks should inherit from Callback and override any of the event-related hooks. For example, below is a simple callback that runs on EPOCH_START and prints the epoch number.

from composer import Callback, State, Logger

class EpochMonitor(Callback):

    def epoch_start(self, state: State, logger: Logger):
        print(f'Epoch: {state.timestamp.epoch}')

Alternatively, one can override Callback.run_event() to run code at every event. Below is an equivalent implementation of EpochMonitor:

from composer import Callback, Event, Logger, State

class EpochMonitor(Callback):

    def run_event(self, event: Event, state: State, logger: Logger):
        if event == Event.EPOCH_START:
            print(f'Epoch: {state.timestamp.epoch}')

Warning

If Callback.run_event() is overridden, the individual methods corresponding to each event will be ignored.
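To see why, it helps to picture how per-event hooks get called. The sketch below uses simplified stand-ins (an Event enum with two members and a bare Callback base class), not Composer's actual classes. It assumes the base run_event looks up the method named after the event, so an override that does not call super() skips that lookup.

```python
from enum import Enum


class Event(Enum):
    # Stand-in with two members; each value is the name of a hook method.
    EPOCH_START = 'epoch_start'
    EPOCH_END = 'epoch_end'


class Callback:
    """Simplified stand-in for a callback base class (assumed dispatch logic)."""

    def run_event(self, event, state, logger):
        # Default dispatch: call the method named after the event.
        return getattr(self, event.value)(state, logger)

    def epoch_start(self, state, logger):
        pass

    def epoch_end(self, state, logger):
        pass


class HookOnly(Callback):
    def __init__(self):
        self.calls = []

    def epoch_start(self, state, logger):
        self.calls.append('epoch_start hook')


class OverridesRunEvent(Callback):
    def __init__(self):
        self.calls = []

    def run_event(self, event, state, logger):
        # Replaces the default dispatch, so epoch_start below is never reached.
        self.calls.append(f'run_event({event.name})')

    def epoch_start(self, state, logger):
        self.calls.append('epoch_start hook')


hook_only, overridden = HookOnly(), OverridesRunEvent()
for callback in (hook_only, overridden):
    callback.run_event(Event.EPOCH_START, state=None, logger=None)

print(hook_only.calls)
print(overridden.calls)
```

In this sketch, calling super().run_event(event, state, logger) inside the override would restore the per-event dispatch.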

The new callback can then be provided to the trainer.

from composer import Trainer

trainer = Trainer(
    ...,
    callbacks=[EpochMonitor()]
)
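A callback that reports a metric, rather than printing, might look like the following sketch. EpochTimer, the metric key time/epoch_seconds, and the RecordingLogger test double are hypothetical names used for illustration. The sketch assumes the logger passed to each hook exposes a log_metrics method that accepts a dictionary, and it falls back to a bare stand-in base class when Composer is not installed so that it can still be run in isolation.

```python
import time

try:
    from composer import Callback
except Exception:  # Composer unavailable: bare stand-in so the sketch still runs
    class Callback:
        pass


class EpochTimer(Callback):
    """Hypothetical callback: logs wall-clock seconds spent in each epoch."""

    def __init__(self):
        super().__init__()
        self._start = None

    def epoch_start(self, state, logger):
        self._start = time.perf_counter()

    def epoch_end(self, state, logger):
        elapsed = time.perf_counter() - self._start
        # Assumes the logger accepts a dict of metric name -> value.
        logger.log_metrics({'time/epoch_seconds': elapsed})


class RecordingLogger:
    """Test double that only mimics log_metrics; not a Composer class."""

    def __init__(self):
        self.records = []

    def log_metrics(self, metrics, step=None):
        self.records.append(metrics)


# Drive the hooks by hand to check the callback without a full training run.
timer = EpochTimer()
fake_logger = RecordingLogger()
timer.epoch_start(state=None, logger=fake_logger)
timer.epoch_end(state=None, logger=fake_logger)
print(fake_logger.records)
```

Calling the hooks directly with a test double keeps the callback easy to unit test, since the callback only reads and logs and never modifies the state.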

Events

Here is the list of supported Event values that callbacks can hook into.

class composer.core.Event(value)

Enum to represent training loop events.

Events mark specific points in the training loop where an Algorithm or a Callback can run.

The following pseudocode shows where each event fires in the training loop:

# <INIT>
# <BEFORE_LOAD>
# <AFTER_LOAD>
# <FIT_START>
for iteration in range(NUM_ITERATIONS):
    # <ITERATION_START>
    for epoch in range(NUM_EPOCHS):
        # <EPOCH_START>
        while True:
            # <BEFORE_DATALOADER>
            batch = next(dataloader)
            if batch is None:
                break
            # <AFTER_DATALOADER>

            # <BATCH_START>

            # <BEFORE_TRAIN_BATCH>

            for microbatch in batch.split(device_train_microbatch_size):

                # <BEFORE_FORWARD>
                outputs = model(microbatch)
                # <AFTER_FORWARD>

                # <BEFORE_LOSS>
                loss = model.loss(outputs, microbatch)
                # <AFTER_LOSS>

                # <BEFORE_BACKWARD>
                loss.backward()
                # <AFTER_BACKWARD>

            # Un-scale gradients

            # <AFTER_TRAIN_BATCH>
            optimizer.step()

            # <BATCH_END>

            # <EVAL_BEFORE_ALL>
            for eval_dataloader in eval_dataloaders:
                if should_eval(batch=True):
                    # <EVAL_START>
                    for batch in eval_dataloader:
                        # <EVAL_BATCH_START>
                        # <EVAL_BEFORE_FORWARD>
                        outputs, targets = model(batch)
                        # <EVAL_AFTER_FORWARD>
                        metrics.update(outputs, targets)
                        # <EVAL_BATCH_END>
                    # <EVAL_END>

            # <EVAL_AFTER_ALL>

            # <BATCH_CHECKPOINT>
        # <EPOCH_END>

        # <EVAL_BEFORE_ALL>
        for eval_dataloader in eval_dataloaders:
            if should_eval(batch=False):
                # <EVAL_START>
                for batch in eval_dataloader:
                    # <EVAL_BATCH_START>
                    # <EVAL_BEFORE_FORWARD>
                    outputs, targets = model(batch)
                    # <EVAL_AFTER_FORWARD>
                    metrics.update(outputs, targets)
                    # <EVAL_BATCH_END>
                # <EVAL_END>

        # <EVAL_AFTER_ALL>

        # <EPOCH_CHECKPOINT>
    # <ITERATION_END>
    # <ITERATION_CHECKPOINT>
# <FIT_END>

INIT

Invoked in the constructor of Trainer. Model surgery (see module_surgery) typically occurs here.

BEFORE_LOAD

Immediately before the checkpoint is loaded in the constructor of Trainer.

AFTER_LOAD

Immediately after the checkpoint is loaded in the constructor of Trainer.

FIT_START

Invoked at the beginning of each call to Trainer.fit(). Dataset transformations typically occur here.

ITERATION_START

Start of an iteration.

EPOCH_START

Start of an epoch.

BEFORE_DATALOADER

Immediately before the dataloader is called.

AFTER_DATALOADER

Immediately after the dataloader is called. Typically used for on-GPU dataloader transforms.

BATCH_START

Start of a batch.

BEFORE_TRAIN_BATCH

Before the forward-loss-backward computation for a training batch. When using gradient accumulation, this is still called only once.

BEFORE_FORWARD

Before the call to model.forward(). This is called multiple times per batch when using gradient accumulation.

AFTER_FORWARD

After the call to model.forward(). This is called multiple times per batch when using gradient accumulation.

BEFORE_LOSS

Before the call to model.loss(). This is called multiple times per batch when using gradient accumulation.

AFTER_LOSS

After the call to model.loss(). This is called multiple times per batch when using gradient accumulation.

BEFORE_BACKWARD

Before the call to loss.backward(). This is called multiple times per batch when using gradient accumulation.

AFTER_BACKWARD

After the call to loss.backward(). This is called multiple times per batch when using gradient accumulation.

AFTER_TRAIN_BATCH

After the forward-loss-backward computation for a training batch. When using gradient accumulation, this event still fires only once.

BATCH_END

End of a batch, which occurs after the optimizer step and any gradient scaling.

BATCH_CHECKPOINT

After Event.BATCH_END and any batch-wise evaluation. Saving checkpoints at this event allows the checkpoint saver to use the results from any batch-wise evaluation to determine whether a checkpoint should be saved.

EPOCH_END

End of an epoch.

EPOCH_CHECKPOINT

After Event.EPOCH_END and any epoch-wise evaluation. Saving checkpoints at this event allows the checkpoint saver to use the results from any epoch-wise evaluation to determine whether a checkpoint should be saved.

ITERATION_END

End of an iteration.

ITERATION_CHECKPOINT

After Event.ITERATION_END. Saving checkpoints at this event allows the checkpoint saver to determine whether a checkpoint should be saved.

FIT_END

Invoked at the end of each call to Trainer.fit(). This event exists primarily for logging information and flushing callbacks. Algorithms should not transform the training state on this event, as any changes will not be preserved in checkpoints.

EVAL_BEFORE_ALL

Before any evaluator processes its validation dataset.

EVAL_START

Start of evaluation through the validation dataset.

EVAL_BATCH_START

Start of an evaluation batch, before Event.EVAL_BEFORE_FORWARD.

EVAL_BEFORE_FORWARD

Before the call to model.eval_forward(batch).

EVAL_AFTER_FORWARD

After the call to model.eval_forward(batch).

EVAL_BATCH_END

End of an evaluation batch, after the call to model.eval_forward(batch) and the metrics update.

EVAL_END

End of evaluation through the validation dataset.

EVAL_AFTER_ALL

After all evaluators have processed their validation datasets.

EVAL_STANDALONE_START

Start of evaluation through a direct call to trainer.eval.

EVAL_STANDALONE_END

End of evaluation through a direct call to trainer.eval.
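The notes on gradient accumulation above (some events fire once per batch, others once per microbatch) can be checked with a small simulation. The sketch below is not Composer code: simulate_training is a hypothetical helper that only records event names in the order given by the pseudocode, for a single epoch, and it omits the evaluation events and the final dataloader probe for brevity.

```python
from collections import Counter


def simulate_training(num_batches: int, microbatches_per_batch: int) -> list:
    """Return the event names fired during one simulated epoch."""
    fired = ['EPOCH_START']
    for _ in range(num_batches):
        # Fired once per batch, regardless of gradient accumulation.
        fired += ['BEFORE_DATALOADER', 'AFTER_DATALOADER',
                  'BATCH_START', 'BEFORE_TRAIN_BATCH']
        for _ in range(microbatches_per_batch):
            # Fired once per microbatch.
            fired += ['BEFORE_FORWARD', 'AFTER_FORWARD',
                      'BEFORE_LOSS', 'AFTER_LOSS',
                      'BEFORE_BACKWARD', 'AFTER_BACKWARD']
        fired += ['AFTER_TRAIN_BATCH', 'BATCH_END', 'BATCH_CHECKPOINT']
    fired += ['EPOCH_END', 'EPOCH_CHECKPOINT']
    return fired


events = simulate_training(num_batches=3, microbatches_per_batch=4)
counts = Counter(events)

print(counts['BEFORE_TRAIN_BATCH'])  # once per batch
print(counts['BEFORE_FORWARD'])      # once per microbatch
```

A callback that accumulates per-batch statistics in BEFORE_FORWARD or AFTER_LOSS would therefore see several calls per optimizer step when microbatching is active, which is worth keeping in mind when choosing a hook.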