☎️ Callbacks
Callbacks provide hooks that run at each Event of the training loop.
By convention, callbacks should not modify the training loop by changing
the State; instead, they should read the State and log various metrics.
Typical callback use cases include logging, timing, and model introspection.
Using Callbacks
Built-in callbacks can be accessed in composer.callbacks and
registered with the callbacks argument to the Trainer.
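from composer import Trainer
from composer.callbacks import SpeedMonitor, LRMonitor
from composer.loggers import WandBLogger

# Assumes `model` (a ComposerModel) and `train_dataloader` are already defined.
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=None,
    max_duration='1ep',
    # Callbacks decide WHAT is logged: throughput and learning rate
    callbacks=[SpeedMonitor(window_size=100), LRMonitor()],
    # Loggers decide WHERE it is sent: Weights & Biases
    loggers=[WandBLogger()],
)
trainer.fit()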
This example registers two callbacks, which measure model throughput and the learning rate, and a logger that sends those metrics to Weights & Biases. Callbacks control what is logged, whereas loggers specify where the information is saved. For more information on loggers, see Logging.
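The split between "what" (callbacks) and "where" (loggers) can be illustrated with a minimal, hypothetical sketch in plain Python. None of these classes are Composer's API: InMemoryLogger, PrintLogger, and LearningRateCallback are illustrative stand-ins, and the training state is modeled as a plain dict. The callback only reads the state and hands the same metrics to every logger; swapping loggers changes the destination without touching the callback.

```python
from typing import Any, Dict, List


class InMemoryLogger:
    """Hypothetical logger: decides WHERE metrics go (here, a Python list)."""

    def __init__(self) -> None:
        self.records: List[Dict[str, Any]] = []

    def log_metrics(self, metrics: Dict[str, Any]) -> None:
        self.records.append(dict(metrics))


class PrintLogger:
    """Hypothetical logger: sends the same metrics to stdout instead."""

    def log_metrics(self, metrics: Dict[str, Any]) -> None:
        print(metrics)


class LearningRateCallback:
    """Hypothetical callback: decides WHAT is logged (the current learning rate).

    It only reads the state; it never writes to it.
    """

    def batch_end(self, state: Dict[str, Any], loggers: List[Any]) -> None:
        metrics = {"lr": state["lr"], "batch": state["batch"]}
        for logger in loggers:  # fan the same metrics out to every destination
            logger.log_metrics(metrics)


# Simulate three batches with a learning rate that halves after each batch.
memory = InMemoryLogger()
callback = LearningRateCallback()
state = {"lr": 0.1, "batch": 0}
for batch in range(3):
    state["batch"] = batch
    callback.batch_end(state, loggers=[memory, PrintLogger()])
    state["lr"] *= 0.5  # the "training loop", not the callback, changes the state
```

Keeping the callback unaware of its destinations is what lets the same callback feed several loggers at once.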
Available Callbacks
Composer provides several callbacks to monitor and log various components of training.
- Saves checkpoints.
- Logs the training throughput and utilization.
- Estimates the total training time.
- Logs the learning rate.
- Computes and logs the L2 norm of gradients, as well as any optimizer-specific metrics implemented in the optimizer's report_per_parameter_metrics method.
- Logs the memory usage of the model.
- Logs image inputs and, optionally, outputs.
- Creates a compliant results file for the MLPerf Training benchmark.
- Halts training when a metric value reaches a certain threshold.
- Tracks a metric and halts training if it does not improve within a given interval.
- Exports the model for inference.
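The two stopping behaviors described above (halt at a threshold, halt when a metric stops improving) can be sketched in plain Python. This is a hypothetical illustration of the general patience technique, not Composer's implementation: PatienceStopper, threshold_reached, patience, and min_delta are names chosen for this sketch, the tracked metric is assumed to be lower-is-better (like a loss), and the loss values are illustrative inputs, not results.

```python
from typing import Optional


class PatienceStopper:
    """Hypothetical early-stopping rule (not Composer's EarlyStopper).

    Tracks the best value of a lower-is-better metric and requests a halt
    once `patience` consecutive evaluations pass without improvement.
    """

    def __init__(self, patience: int, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta  # minimum decrease that counts as improvement
        self.best: Optional[float] = None
        self.evals_without_improvement = 0

    def should_stop(self, value: float) -> bool:
        if self.best is None or value < self.best - self.min_delta:
            self.best = value  # new best value: reset the counter
            self.evals_without_improvement = 0
        else:
            self.evals_without_improvement += 1
        return self.evals_without_improvement >= self.patience


def threshold_reached(value: float, threshold: float) -> bool:
    """Hypothetical threshold rule: halt once a higher-is-better metric reaches `threshold`."""
    return value >= threshold


# Illustrative validation losses from successive evaluations.
losses = [1.0, 0.8, 0.79, 0.81, 0.80, 0.82]
stopper = PatienceStopper(patience=3)
stopped_at = None
for i, loss in enumerate(losses):
    if stopper.should_stop(loss):
        stopped_at = i  # the third evaluation in a row without a new best
        break
```

Here the best loss is 0.79 at index 2, and the three following evaluations fail to beat it, so the loop stops at index 5.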
Custom Callbacks
Custom callbacks should inherit from Callback and override any of the
event-related hooks. For example, below is a simple callback that runs on
EPOCH_START and prints the epoch number.
from composer import Callback, State, Logger
class EpochMonitor(Callback):
    def epoch_start(self, state: State, logger: Logger):
        print(f'Epoch: {state.timestamp.epoch}')
Alternatively, one can override Callback.run_event() to run code
at every event. The following is an equivalent implementation of EpochMonitor:
from composer import Callback, Event, Logger, State
class EpochMonitor(Callback):
    def run_event(self, event: Event, state: State, logger: Logger):
        if event == Event.EPOCH_START:
            print(f'Epoch: {state.timestamp.epoch}')
Warning
If Callback.run_event() is overridden, the individual methods corresponding
to each event will be ignored.
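To see why overriding run_event() causes the individual hook methods to be ignored, consider a simplified, hypothetical model of the dispatch: a base class whose default run_event() looks up a method named after the event and calls it. This is an assumption made for illustration, not Composer's source; the Event enum (a three-member subset), Callback, and the helper fire() below are standalone stand-ins, and the state is a plain dict.

```python
from enum import Enum
from typing import List


class Event(Enum):
    """A subset of event names; each value matches a hook method name."""
    FIT_START = "fit_start"
    EPOCH_START = "epoch_start"
    EPOCH_END = "epoch_end"


class Callback:
    """Hypothetical base class: default run_event() dispatches by name."""

    def run_event(self, event: Event, state: dict) -> None:
        handler = getattr(self, event.value)  # e.g. EPOCH_START -> self.epoch_start
        handler(state)

    # Default no-op hooks; subclasses override only the ones they need.
    def fit_start(self, state: dict) -> None: ...
    def epoch_start(self, state: dict) -> None: ...
    def epoch_end(self, state: dict) -> None: ...


class MethodCallback(Callback):
    """Overrides one hook; the inherited run_event() dispatches to it."""

    def __init__(self) -> None:
        self.seen: List[str] = []

    def epoch_start(self, state: dict) -> None:
        self.seen.append(f"epoch_start {state['epoch']}")


class RunEventCallback(MethodCallback):
    """Overrides run_event() itself, so the inherited epoch_start() never runs."""

    def run_event(self, event: Event, state: dict) -> None:
        self.seen.append(f"run_event {event.name}")


def fire(callbacks: List[Callback], state: dict) -> None:
    """Fire every event, in declaration order, on every callback."""
    for event in Event:
        for cb in callbacks:
            cb.run_event(event, state)


a, b = MethodCallback(), RunEventCallback()
fire([a, b], {"epoch": 0})
print(a.seen)  # ['epoch_start 0']
print(b.seen)  # ['run_event FIT_START', 'run_event EPOCH_START', 'run_event EPOCH_END']
```

RunEventCallback still defines epoch_start() through inheritance, but nothing calls it: the override replaced the only code path that would have.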
The new callback can then be provided to the trainer.
from composer import Trainer
trainer = Trainer(
    ...,
    callbacks=[EpochMonitor()]
)
Events
Here is the list of supported Events that callbacks can hook into.
class composer.core.Event(value)
Enum to represent training loop events. Events mark specific points in the training loop where an Algorithm and a Callback can run.
The following pseudocode shows where each event fires in the training loop:
# <INIT>
# <AFTER_LOAD>
# <FIT_START>
for epoch in range(NUM_EPOCHS):
    # <EPOCH_START>
    while True:
        # <BEFORE_DATALOADER>
        batch = next(dataloader)
        if batch is None:
            break
        # <AFTER_DATALOADER>
        # <BATCH_START>
        # <BEFORE_TRAIN_BATCH>
        for microbatch in batch.split(grad_accum):
            # <BEFORE_FORWARD>
            outputs = model(microbatch)
            # <AFTER_FORWARD>
            # <BEFORE_LOSS>
            loss = model.loss(outputs, microbatch)
            # <AFTER_LOSS>
            # <BEFORE_BACKWARD>
            loss.backward()
            # <AFTER_BACKWARD>
        # Un-scale and clip gradients
        # <AFTER_TRAIN_BATCH>
        optimizer.step()
        # <BATCH_END>
        if should_eval(batch=True):
            for eval_dataloader in eval_dataloaders:
                # <EVAL_START>
                for batch in eval_dataloader:
                    # <EVAL_BATCH_START>
                    # <EVAL_BEFORE_FORWARD>
                    outputs, targets = model(batch)
                    # <EVAL_AFTER_FORWARD>
                    metrics.update(outputs, targets)
                    # <EVAL_BATCH_END>
                # <EVAL_END>
        # <BATCH_CHECKPOINT>
    # <EPOCH_END>
    if should_eval(batch=False):
        for eval_dataloader in eval_dataloaders:
            # <EVAL_START>
            for batch in eval_dataloader:
                # <EVAL_BATCH_START>
                # <EVAL_BEFORE_FORWARD>
                outputs, targets = model(batch)
                # <EVAL_AFTER_FORWARD>
                metrics.update(outputs, targets)
                # <EVAL_BATCH_END>
            # <EVAL_END>
    # <EPOCH_CHECKPOINT>
# <FIT_END>
- INIT: Invoked in the constructor of Trainer. Model surgery (see module_surgery) typically occurs here.
- AFTER_LOAD: Immediately after a checkpoint is loaded in the constructor of Trainer.
- FIT_START: Invoked at the beginning of each call to Trainer.fit(). Dataset transformations typically occur here.
- EPOCH_START: Start of an epoch.
- BEFORE_DATALOADER: Immediately before the dataloader is called.
- AFTER_DATALOADER: Immediately after the dataloader is called. Typically used for on-GPU dataloader transforms.
- BATCH_START: Start of a batch.
- BEFORE_TRAIN_BATCH: Before the forward-loss-backward computation for a training batch. When using gradient accumulation, this event still fires only once per batch.
- BEFORE_FORWARD: Before the call to model.forward(). This event fires once per microbatch, so it fires multiple times per batch when using gradient accumulation.
- AFTER_FORWARD: After the call to model.forward(). This event fires once per microbatch, so it fires multiple times per batch when using gradient accumulation.
- BEFORE_LOSS: Before the call to model.loss(). This event fires once per microbatch, so it fires multiple times per batch when using gradient accumulation.
- AFTER_LOSS: After the call to model.loss(). This event fires once per microbatch, so it fires multiple times per batch when using gradient accumulation.
- BEFORE_BACKWARD: Before the call to loss.backward(). This event fires once per microbatch, so it fires multiple times per batch when using gradient accumulation.
- AFTER_BACKWARD: After the call to loss.backward(). This event fires once per microbatch, so it fires multiple times per batch when using gradient accumulation.
- AFTER_TRAIN_BATCH: After the forward-loss-backward computation for a training batch. When using gradient accumulation, this event still fires only once per batch.
- BATCH_END: End of a batch, which occurs after the optimizer step and any gradient scaling.
- BATCH_CHECKPOINT: After Event.BATCH_END and any batch-wise evaluation. Saving checkpoints at this event allows the checkpoint saver to use the results from any batch-wise evaluation to determine whether a checkpoint should be saved.
- EPOCH_END: End of an epoch.
- EPOCH_CHECKPOINT: After Event.EPOCH_END and any epoch-wise evaluation. Saving checkpoints at this event allows the checkpoint saver to use the results from any epoch-wise evaluation to determine whether a checkpoint should be saved.
- FIT_END: Invoked at the end of each call to Trainer.fit(). This event exists primarily for logging information and flushing callbacks. Algorithms should not transform the training state on this event, as any changes will not be preserved in checkpoints.
- EVAL_START: Start of evaluation through the validation dataset.
- EVAL_BATCH_START: Start of an evaluation batch, before Event.EVAL_BEFORE_FORWARD and the call to model.eval_forward(batch).
- EVAL_BEFORE_FORWARD: Immediately before the call to model.eval_forward(batch).
- EVAL_AFTER_FORWARD: Immediately after the call to model.eval_forward(batch).
- EVAL_BATCH_END: End of an evaluation batch, after Event.EVAL_AFTER_FORWARD and the metric update.
- EVAL_END: End of evaluation through the validation dataset.
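The once-per-batch versus once-per-microbatch distinction in the event list can be checked by replaying the training-loop pseudocode. The sketch below is a hypothetical, standalone simulation (simulate_events and its parameters are names chosen here, not Composer's API); it follows the pseudocode's order, including the final dataloader call that returns None and ends each epoch, and it omits evaluation events by assuming no eval dataloaders.

```python
from collections import Counter
from typing import List


def simulate_events(num_epochs: int, batches_per_epoch: int, microbatches: int) -> List[str]:
    """Return training events in the order the pseudocode fires them.

    Hypothetical sketch: evaluation events are omitted (no eval dataloaders).
    """
    trace: List[str] = ["INIT", "AFTER_LOAD", "FIT_START"]
    for _ in range(num_epochs):
        trace.append("EPOCH_START")
        for _ in range(batches_per_epoch):
            trace += ["BEFORE_DATALOADER", "AFTER_DATALOADER", "BATCH_START", "BEFORE_TRAIN_BATCH"]
            for _ in range(microbatches):  # gradient accumulation: one pass per microbatch
                trace += [
                    "BEFORE_FORWARD", "AFTER_FORWARD",
                    "BEFORE_LOSS", "AFTER_LOSS",
                    "BEFORE_BACKWARD", "AFTER_BACKWARD",
                ]
            trace += ["AFTER_TRAIN_BATCH", "BATCH_END", "BATCH_CHECKPOINT"]
        # The last dataloader call returns None and breaks out of the batch loop.
        trace.append("BEFORE_DATALOADER")
        trace += ["EPOCH_END", "EPOCH_CHECKPOINT"]
    trace.append("FIT_END")
    return trace


trace = simulate_events(num_epochs=2, batches_per_epoch=3, microbatches=4)
counts = Counter(trace)
print(counts["BEFORE_TRAIN_BATCH"])  # once per batch: 2 * 3 -> 6
print(counts["BEFORE_FORWARD"])      # once per microbatch: 2 * 3 * 4 -> 24
```

A callback that accumulates per-batch statistics would therefore hook BEFORE_TRAIN_BATCH or AFTER_TRAIN_BATCH rather than the forward or loss events, which repeat for every microbatch.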