🛑 Early Stopping
Early stopping and threshold stopping halt training based on set criteria. In Composer, this functionality is implemented as callbacks which can be configured and passed to the Trainer.
The EarlyStopper callback stops training if a provided metric does not improve within a certain patience window, measured in batches or epochs.
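```python
import torch

from composer import Trainer
from composer.callbacks.early_stopper import EarlyStopper

# Stop training if the train-set MulticlassAccuracy does not improve on its
# best recorded value by min_delta within the patience window.
early_stopper = EarlyStopper(
    monitor='MulticlassAccuracy',  # name of the metric to track
    dataloader_label='train',      # track the metric on the train dataloader
    patience='50ba',               # time string: 50 batches
    comp=torch.greater,            # higher values are better
    min_delta=0.01,                # smallest change that counts as improvement
)

# model, train_dataloader, eval_dataloader, and optimizer are assumed to be
# defined elsewhere.
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    optimizers=optimizer,
    callbacks=[early_stopper],
    max_duration="1ep",
)
```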
In the above example, the 'train' label means the callback is tracking the MulticlassAccuracy metric for the train_dataloader. The default label for the evaluation dataloader is 'eval'.
We also set patience='50ba' and min_delta=0.01, which means that if 50 batches pass without the accuracy exceeding the best recorded accuracy by 0.01, training is stopped. The comp argument indicates that ‘better’ here means higher accuracy. Note that the patience parameter can take either a time string (see Time) or an integer, which specifies a number of epochs.
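The stopping rule described above can be modeled in plain Python. This is a simplified sketch under stated assumptions, not Composer's implementation: metric values are plain floats observed once per check (for example, once per batch), operator.gt stands in for torch.greater, and the function name early_stop_step is hypothetical. As modeled here, a value counts as an improvement only if comp says it is better than the best so far and it differs from that best by more than min_delta. Training stops once more than patience checks have passed since the last improvement.

```python
import operator
from typing import Callable, Optional, Sequence


def early_stop_step(
    values: Sequence[float],
    patience: int,
    min_delta: float = 0.0,
    comp: Callable[[float, float], bool] = operator.gt,
) -> Optional[int]:
    """Return the index of the check at which training would stop, or None.

    Hypothetical model of early stopping, not Composer's code.
    """
    best: Optional[float] = None
    best_step = 0
    for step, value in enumerate(values):
        # An improvement must be better per `comp` AND larger than min_delta.
        if best is None or (comp(value, best) and abs(value - best) > min_delta):
            best, best_step = value, step
        # No improvement: stop once the patience window has been exceeded.
        elif step - best_step > patience:
            return step
    return None


accuracies = [0.50, 0.60, 0.605, 0.603, 0.606, 0.62]
print(early_stop_step(accuracies, patience=2, min_delta=0.01))  # 4
print(early_stop_step(accuracies, patience=2, min_delta=0.0))   # None
```

With min_delta=0.01, the small gains after 0.60 do not count, so the window runs out at the fifth check. With min_delta=0.0, each small gain resets the window. Passing operator.lt instead models a metric where lower is better, such as a loss.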
For a full list of arguments, see the documentation for EarlyStopper.
When monitoring metrics from the eval_dataloader, make sure that your patience is at least a few multiples of the eval_interval (e.g., if eval_interval='1ep', consider patience='4ep'), so that the callback has a few data points with which to measure improvement.
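The guidance above reduces to a division. When both values are expressed in epochs, the number of evaluations that land inside one patience window is patience // eval_interval. The helpers below are hypothetical illustrations of that arithmetic and are not part of Composer. The min_evals=3 cutoff is an arbitrary assumption standing in for "a few".

```python
def evals_per_patience_window(patience_epochs: int, eval_interval_epochs: int) -> int:
    """Count evaluations that fall inside one patience window (both values in epochs)."""
    if eval_interval_epochs <= 0:
        raise ValueError("eval_interval_epochs must be positive")
    return patience_epochs // eval_interval_epochs


def patience_is_long_enough(
    patience_epochs: int, eval_interval_epochs: int, min_evals: int = 3
) -> bool:
    """Hypothetical check: does one patience window contain at least min_evals evaluations?"""
    return evals_per_patience_window(patience_epochs, eval_interval_epochs) >= min_evals


# eval_interval='1ep' with patience='4ep' gives 4 evaluations per window.
print(evals_per_patience_window(4, 1))  # 4
# With patience equal to eval_interval, only one evaluation fits in the window.
print(patience_is_long_enough(1, 1))  # False
```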
The ThresholdStopper callback also monitors a specific metric, but halts training when that metric reaches a certain threshold.
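```python
from composer import Trainer
from composer.callbacks.threshold_stopper import ThresholdStopper

# Stop training once the eval-set MulticlassAccuracy exceeds 0.8.
threshold_stopper = ThresholdStopper(
    monitor='MulticlassAccuracy',  # name of the metric to track
    dataloader_label='eval',       # default label of the eval_dataloader
    threshold=0.8,                 # stop once the metric passes this value
)

# model, train_dataloader, eval_dataloader, and optimizer are assumed to be
# defined elsewhere.
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    optimizers=optimizer,
    callbacks=[threshold_stopper],
    max_duration='1ep',
)
```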
In this example, training will exit when the model’s validation accuracy exceeds 0.8. For a full list of arguments, see the documentation for ThresholdStopper.
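The threshold rule can be sketched the same way. This is an illustrative model and not Composer's code. It treats values as plain floats observed at each check, uses the hypothetical name threshold_stop_step, and uses operator.gt as the comparison, so that "exceeds 0.8" means the value must be strictly greater than the threshold.

```python
import operator
from typing import Callable, Optional, Sequence


def threshold_stop_step(
    values: Sequence[float],
    threshold: float,
    comp: Callable[[float, float], bool] = operator.gt,
) -> Optional[int]:
    """Return the index of the first check whose value passes `threshold`, or None.

    Hypothetical model of threshold stopping, not Composer's code.
    """
    for step, value in enumerate(values):
        # Stop as soon as the comparison against the threshold succeeds.
        if comp(value, threshold):
            return step
    return None


print(threshold_stop_step([0.55, 0.72, 0.81, 0.79], threshold=0.8))  # 2
```

Unlike the early-stopping sketch, this rule keeps no history. It only compares the current value with a fixed target. Passing operator.lt models a lower-is-better metric such as a loss.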
Evaluators and Multiple Metrics
When there are multiple datasets and metrics to use for validation and evaluation, Evaluator objects can be used to pass multiple dataloaders to the trainer. Each Evaluator can have multiple metrics associated with it. See Evaluation for more details.
Each Evaluator object is marked with a label field for logging and a metric_names field that accepts a list of metric names. To choose which metric the callbacks above monitor, pass the label as dataloader_label and one of the metric names as monitor.
In the example below, the callback monitors the MulticlassAccuracy metric for the dataloader labeled eval_dataset1.
```python
from composer import Trainer, Evaluator
from composer.callbacks.early_stopper import EarlyStopper

# eval_dataloader, eval_dataloader2, model, train_dataloader, and optimizer
# are assumed to be defined elsewhere.
evaluator1 = Evaluator(
    label='eval_dataset1',
    dataloader=eval_dataloader,
    metric_names=['MulticlassAccuracy'],
)
evaluator2 = Evaluator(
    label='eval_dataset2',
    dataloader=eval_dataloader2,
    metric_names=['MulticlassAccuracy'],
)

# Monitor MulticlassAccuracy on evaluator1 only, selected by its label.
early_stopper = EarlyStopper(
    monitor='MulticlassAccuracy',
    dataloader_label='eval_dataset1',
    patience=1,  # integer patience: 1 epoch
)

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=[evaluator1, evaluator2],
    optimizers=optimizer,
    callbacks=[early_stopper],
    max_duration="1ep",
)
```
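In the example above, dataloader_label and monitor act as a two-level key. The label picks the evaluator, and the metric name picks one of that evaluator's metrics. The sketch below models this lookup with a nested dictionary of plain floats. The dictionary layout, the example values, and the get_monitored_value helper are assumptions for illustration and are not Composer's internal API.

```python
from typing import Mapping


def get_monitored_value(
    metrics: Mapping[str, Mapping[str, float]],
    dataloader_label: str,
    monitor: str,
) -> float:
    """Return the value of `monitor` for the dataloader labeled `dataloader_label`.

    Hypothetical helper: `metrics` maps label -> metric name -> value.
    """
    if dataloader_label not in metrics:
        raise KeyError(f"no dataloader labeled {dataloader_label!r}")
    if monitor not in metrics[dataloader_label]:
        raise KeyError(f"no metric {monitor!r} for dataloader {dataloader_label!r}")
    return metrics[dataloader_label][monitor]


# Hypothetical results after one evaluation pass over both evaluators.
results = {
    'eval_dataset1': {'MulticlassAccuracy': 0.74},
    'eval_dataset2': {'MulticlassAccuracy': 0.68},
}
print(get_monitored_value(results, 'eval_dataset1', 'MulticlassAccuracy'))  # 0.74
```

Both evaluators report a metric with the same name, so the label alone determines which value the callback reads.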