Early Stopping#
Early stopping and threshold stopping halt training based on set criteria. In Composer, this functionality is implemented as callbacks which can be configured and passed to the Trainer.
Early Stopping#
The EarlyStopper callback stops training if a provided metric does not improve over a certain patience window of time.
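import torch

from composer import Trainer
from composer.callbacks.early_stopper import EarlyStopper

# model, train_dataloader, eval_dataloader, and optimizer are assumed to be defined elsewhere.

early_stopper = EarlyStopper(
    monitor='MulticlassAccuracy',  # name of the metric to track
    dataloader_label='train',      # track the metric on the training dataloader
    patience='50ba',               # allow 50 batches without improvement
    comp=torch.greater,            # "better" means a higher value
    min_delta=0.01,                # minimum change that counts as an improvement
)

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    optimizers=optimizer,
    callbacks=[early_stopper],
    max_duration='1ep',
)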
In the above example, the 'train' label means the callback tracks the MulticlassAccuracy metric on the train_dataloader. The default label for the evaluation dataloader is 'eval'.
We also set patience='50ba' and min_delta=0.01, which means that training stops if MulticlassAccuracy goes 50 batches without exceeding its best recorded value by more than 0.01. The comp argument indicates that "better" here means higher accuracy. Note that the patience parameter accepts either a time string (see Time) or an integer, which specifies a number of epochs.
For a full list of arguments, see the documentation for EarlyStopper.
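The interaction between patience and min_delta described above can be sketched in plain Python. This is a simplified illustration, not Composer's implementation: the helper name first_stop_step and the sample readings are hypothetical, and the exact boundary condition inside EarlyStopper may differ.

```python
import operator


def first_stop_step(values, patience, min_delta=0.0, comp=operator.gt):
    """Return the step at which a patience-based stopper would halt, or None.

    `values` holds one metric reading per step (hypothetical data).
    A reading counts as an improvement only if `comp` says it is better
    than the best so far AND it differs from the best by more than `min_delta`.
    """
    best = None
    best_step = 0
    for step, value in enumerate(values):
        improved = best is None or (
            comp(value, best) and abs(value - best) > min_delta
        )
        if improved:
            best, best_step = value, step
        elif step - best_step >= patience:
            # No sufficiently large improvement for `patience` steps.
            return step
    return None


readings = [0.50, 0.60, 0.605, 0.604, 0.606, 0.70]
stop_short = first_stop_step(readings, patience=3, min_delta=0.01)
stop_long = first_stop_step(readings, patience=5, min_delta=0.01)
```

With a patience of 3 steps, the small gains after the second reading never clear min_delta, so the sketch stops before the late jump to 0.70 is seen. A patience of 5 steps lets training reach that jump.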
Note
When monitoring metrics from the eval_dataloader, make sure that your patience is at least a few multiples of the eval_interval (e.g. if eval_interval='1ep', patience='4ep'), so that the callback has a few datapoints with which to measure improvement.
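As a rough sanity check on the note above, the number of evaluation datapoints that fit inside one patience window can be computed directly. The helper names and the choice of three datapoints as the minimum are assumptions for illustration, not Composer settings; both durations are taken as whole numbers of epochs.

```python
def evals_within_patience(patience_epochs: int, eval_interval_epochs: int) -> int:
    """Number of evaluations that occur during one patience window."""
    if eval_interval_epochs <= 0:
        raise ValueError("eval_interval_epochs must be positive")
    return patience_epochs // eval_interval_epochs


def patience_is_reasonable(
    patience_epochs: int,
    eval_interval_epochs: int,
    min_evals: int = 3,  # assumed reading of "a few" datapoints
) -> bool:
    """True if the patience window covers at least `min_evals` evaluations."""
    return evals_within_patience(patience_epochs, eval_interval_epochs) >= min_evals


# eval_interval='1ep' with patience='4ep' gives four datapoints per window.
datapoints = evals_within_patience(patience_epochs=4, eval_interval_epochs=1)
```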
Threshold Stopping#
The ThresholdStopper callback also monitors a specific metric, but halts training when that metric reaches a certain threshold.
from composer import Trainer
from composer.callbacks.threshold_stopper import ThresholdStopper

threshold_stopper = ThresholdStopper(
    monitor='MulticlassAccuracy',
    dataloader_label='eval',
    threshold=0.8,
)

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    optimizers=optimizer,
    callbacks=[threshold_stopper],
    max_duration='1ep',
)
In this example, training will exit when the model's validation accuracy exceeds 0.8. For a full list of arguments, see the documentation for ThresholdStopper.
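The threshold rule is simpler than the patience rule: training halts the first time the metric crosses a fixed value, regardless of history. A minimal sketch in plain Python, assuming a hypothetical helper first_threshold_hit and made-up readings (this is not Composer's implementation):

```python
import operator


def first_threshold_hit(values, threshold, comp=operator.gt):
    """Return the first step whose reading passes `threshold`, or None.

    `comp(value, threshold)` decides what "passes" means: operator.gt for
    metrics where higher is better, operator.lt for losses.
    """
    for step, value in enumerate(values):
        if comp(value, threshold):
            return step
    return None


accuracy = [0.55, 0.72, 0.80, 0.83]
hit_accuracy = first_threshold_hit(accuracy, threshold=0.8)

loss = [1.2, 0.9, 0.4]
hit_loss = first_threshold_hit(loss, threshold=0.5, comp=operator.lt)
```

Note that with a strict comparison a reading of exactly 0.80 does not trigger the stop; only the later 0.83 does.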
Evaluators and Multiple Metrics#
When there are multiple datasets and metrics to use for validation and evaluation, Evaluator objects can be used to pass multiple dataloaders to the trainer. Each Evaluator object can have multiple metrics associated with it. See Evaluation for more details.
Each Evaluator object is marked with a label field for logging, and a metric_names field that accepts a list of metric names. These can be provided to the callbacks above to indicate which metric to monitor.
In the example below, the callback will monitor the MulticlassAccuracy metric in the dataloader marked eval_dataset1.
from composer import Trainer, Evaluator
from composer.callbacks.early_stopper import EarlyStopper

evaluator1 = Evaluator(
    label='eval_dataset1',
    dataloader=eval_dataloader,
    metric_names=['MulticlassAccuracy'],
)

evaluator2 = Evaluator(
    label='eval_dataset2',
    dataloader=eval_dataloader2,
    metric_names=['MulticlassAccuracy'],
)

early_stopper = EarlyStopper(
    monitor='MulticlassAccuracy',
    dataloader_label='eval_dataset1',  # matches the label of evaluator1
    patience=1,                        # an integer patience is a number of epochs
)

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=[evaluator1, evaluator2],
    optimizers=optimizer,
    callbacks=[early_stopper],
    max_duration='1ep',
)
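Because both evaluators above report a metric with the same name, the label is what tells the callback which value to watch. The lookup can be pictured as a two-level mapping from label to metric name to value. The dictionary layout, the numbers, and the helper lookup_metric below are hypothetical and only illustrate the idea; they are not Composer's internal data structures.

```python
def lookup_metric(results, dataloader_label, monitor):
    """Pick one metric value out of results keyed by label, then metric name."""
    if dataloader_label not in results:
        raise KeyError(f"no evaluator labelled {dataloader_label!r}")
    metrics = results[dataloader_label]
    if monitor not in metrics:
        raise KeyError(f"{monitor!r} not computed for {dataloader_label!r}")
    return metrics[monitor]


# Made-up values: the same metric name appears under two different labels.
results = {
    'eval_dataset1': {'MulticlassAccuracy': 0.81},
    'eval_dataset2': {'MulticlassAccuracy': 0.64},
}

monitored = lookup_metric(results, 'eval_dataset1', 'MulticlassAccuracy')
```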