Early Stopping#
In this tutorial, we're going to learn how to perform early stopping in Composer using callbacks.
In Composer, callbacks modify trainer behavior and are called at the relevant Events in the training loop. This tutorial focuses on two callbacks, the EarlyStopper and ThresholdStopper, both of which halt training early depending on different criteria.
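To give a sense of what that mechanism looks like in code, here is a minimal sketch of a custom callback; the class name and the print statement are purely illustrative, and in this tutorial we'll only use the built-in callbacks:
from composer import Callback, Logger, State

class EpochPrinterCallback(Callback):
    """Illustrative only: a callback that runs at the end of each training epoch."""

    def epoch_end(self, state: State, logger: Logger):
        # state.timestamp tracks training progress in epochs, batches, samples, etc.
        print(f"Finished epoch {int(state.timestamp.epoch)}")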
Recommended Background#
This tutorial assumes that you're generally familiar with techniques such as early stopping.
You should also be comfortable with the material in the Getting Started tutorial before you embark on this slightly more advanced tutorial.
Finally, you'll probably have a better intuition for how the demonstrated features work if you brush up on Composer's event-driven design in our Welcome Tour.
Tutorial Goals and Concepts Covered#
The goal of this tutorial is to demonstrate a basic training run using one of our callbacks to control if/when training stops before the maximum training duration.
We'll demonstrate how to use both the EarlyStopper and the ThresholdStopper callbacks to stop training early.
A comprehensive overview of Composer callbacks is outside the scope of this tutorial, but this should introduce you to some useful tools and give you a sense of the ways callbacks can be used to modify training behavior.
Let's get started!
Setup#
In this tutorial, we'll train a ComposerModel
and halt training based on criteria that we'll set. We'll use the same basic setup as in the Getting Started tutorial. If you want to better understand the details of the setup, that's a good place to review.
Install Composer#
First, install Composer if you haven't already:
[ ]:
%pip install mosaicml
# To install from source instead of the last release, comment the command above and uncomment the following one.
# %pip install git+https://github.com/mosaicml/composer.git
Seed#
Next, we'll set the seed for reproducibility:
[ ]:
from composer.utils.reproducibility import seed_all
seed_all(42)
Dataloader Setup#
Next, instantiate the training and evaluation datasets for CIFAR-10:
[ ]:
import torch.utils.data
from torchvision import datasets, transforms
data_directory = "./data"
# Normalization constants
mean = (0.507, 0.487, 0.441)
std = (0.267, 0.256, 0.276)
batch_size = 1024
cifar10_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
train_dataset = datasets.CIFAR10(data_directory, train=True, download=True, transform=cifar10_transforms)
eval_dataset = datasets.CIFAR10(data_directory, train=False, download=True, transform=cifar10_transforms)
# Setting shuffle=False to allow for easy overfitting in this example
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)
Model, Optimizer, Scheduler, and Evaluator Setup#
Finally, set up the model, optimizer, scheduler, and an evaluator.
[ ]:
from composer.models import ComposerClassifier
from composer.optim import DecoupledSGDW, LinearWithWarmupScheduler
from composer.core import Evaluator
import torch
import torch.nn as nn
import torch.nn.functional as F
class Block(nn.Module):
    """A ResNet block."""

    def __init__(self, f_in: int, f_out: int, downsample: bool = False):
        super(Block, self).__init__()

        stride = 2 if downsample else 1
        self.conv1 = nn.Conv2d(f_in, f_out, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(f_out)
        self.conv2 = nn.Conv2d(f_out, f_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(f_out)
        self.relu = nn.ReLU(inplace=True)

        # No parameters for shortcut connections.
        if downsample or f_in != f_out:
            self.shortcut = nn.Sequential(
                nn.Conv2d(f_in, f_out, kernel_size=1, stride=2, bias=False),
                nn.BatchNorm2d(f_out),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x: torch.Tensor):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return self.relu(out)
class ResNetCIFAR(nn.Module):
    """A residual neural network as originally designed for CIFAR-10."""

    def __init__(self, outputs: int = 10):
        super(ResNetCIFAR, self).__init__()

        depth = 56
        width = 16
        num_blocks = (depth - 2) // 6
        plan = [(width, num_blocks), (2 * width, num_blocks), (4 * width, num_blocks)]
        self.num_classes = outputs

        # Initial convolution.
        current_filters = plan[0][0]
        self.conv = nn.Conv2d(3, current_filters, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(current_filters)
        self.relu = nn.ReLU(inplace=True)

        # The subsequent blocks of the ResNet.
        blocks = []
        for segment_index, (filters, num_blocks) in enumerate(plan):
            for block_index in range(num_blocks):
                downsample = segment_index > 0 and block_index == 0
                blocks.append(Block(current_filters, filters, downsample))
                current_filters = filters
        self.blocks = nn.Sequential(*blocks)

        # Final fc layer. Size = number of filters in last segment.
        self.fc = nn.Linear(plan[-1][0], outputs)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor):
        out = self.relu(self.bn(self.conv(x)))
        out = self.blocks(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
model = ComposerClassifier(module=ResNetCIFAR(), num_classes=10)
optimizer = DecoupledSGDW(
    model.parameters(),   # Model parameters to update
    lr=0.05,              # Peak learning rate
    momentum=0.9,
    weight_decay=2.0e-3,  # If this looks large, it's because it's not scaled by the LR as in non-decoupled weight decay
)
lr_scheduler = LinearWithWarmupScheduler(
    t_warmup="1ep",  # Warm up over 1 epoch
    alpha_i=1.0,     # Flat LR schedule achieved by having alpha_i == alpha_f
    alpha_f=1.0,
)
evaluator = Evaluator(
    dataloader=eval_dataloader,
    label="eval",
    metric_names=['MulticlassAccuracy'],
)
EarlyStopper#
The EarlyStopper
callback tracks a particular training or evaluation metric and stops training if the metric does not improve within a given time interval.
The callback takes the following parameters:
monitor: The name of the metric to track.
dataloader_label: This string identifies which dataloader the metric belongs to. By default, the train dataloader is labeled train, and the evaluation dataloader is labeled eval. (These names can be customized via the train_dataloader_label argument in the Trainer or the label argument of the Evaluator, respectively.)
patience: The amount of time the callback will wait before stopping training if the metric is not improving. You can use integers to specify the number of epochs or provide a Time string, e.g., "50ba" or "2ep" for 50 batches and 2 epochs, respectively.
min_delta: If non-zero, the change in the tracked metric over the patience window must be at least this large.
comp: A comparison operator can be provided to measure the change in the monitored metric. The comparison operator will be called like comp(current_value, previous_best).
See the API Reference for more information.
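For instance, here is a hedged sketch of a more fully specified EarlyStopper; the variable name and the specific patience and min_delta values are illustrative, not recommendations:
from composer.callbacks import EarlyStopper

# Illustrative values: stop if eval MulticlassAccuracy fails to improve
# by at least 0.01 over any 2-epoch window
early_stopper_example = EarlyStopper(
    monitor="MulticlassAccuracy",
    dataloader_label="eval",
    patience="2ep",  # Time string: wait up to 2 epochs for improvement
    min_delta=0.01,  # required improvement over the patience window
)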
Here, we'll use our callback to track the MulticlassAccuracy metric over one epoch on the test dataset:
[ ]:
from composer.callbacks import EarlyStopper
early_stopper = EarlyStopper(monitor="MulticlassAccuracy", dataloader_label="eval", patience=1)
Now that we have our callback, we can instantiate a Composer trainer and train:
[ ]:
from composer.trainer import Trainer
# Early stopping should stop training before we reach 100 epochs!
train_epochs = "100ep"
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=evaluator,
    max_duration=train_epochs,
    optimizers=optimizer,
    schedulers=lr_scheduler,
    callbacks=[early_stopper],    # Instruct the trainer to use our early stopping callback
    train_subset_num_batches=10,  # Only training on a subset of the data to trigger the callback sooner
)
# Train!
trainer.fit()
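If you want to confirm that the callback actually cut the run short of 100 epochs, one quick check (a suggestion, not part of the original recipe) is to inspect the trainer's timestamp after fitting:
# How far did training get before the EarlyStopper halted it?
print(f"Training stopped after {int(trainer.state.timestamp.epoch)} epochs "
      f"({int(trainer.state.timestamp.batch)} batches).")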
ThresholdStopper#
The ThresholdStopper callback is similar to the EarlyStopper, but it halts training when the metric crosses a threshold set in the ThresholdStopper callback.
This callback takes the following parameters:
monitor, dataloader_label, and comp: Same as the EarlyStopper callback.
threshold: The float threshold that dictates when to halt training.
stop_on_batch: If True, training can halt in the middle of an epoch, rather than just at the end.
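As an example, here is a hedged sketch of a ThresholdStopper that can halt mid-epoch by watching a training metric; the threshold value and variable name are illustrative, and this assumes the model also reports MulticlassAccuracy on the train dataloader (ComposerClassifier does by default):
from composer.callbacks import ThresholdStopper

# Illustrative values: halt as soon as the training-set MulticlassAccuracy
# crosses 0.5, even in the middle of an epoch
threshold_stopper_example = ThresholdStopper(
    monitor="MulticlassAccuracy",
    dataloader_label="train",  # train metrics update every batch, so mid-epoch stopping is possible
    threshold=0.5,
    stop_on_batch=True,
)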
We will reuse the same setup for the ThresholdStopper example.
[ ]:
from composer.callbacks import ThresholdStopper
threshold_stopper = ThresholdStopper("MulticlassAccuracy", "eval", threshold=0.3)
# Threshold stopping should stop training before we reach 100 epochs!
train_epochs = "100ep"
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=evaluator,
    max_duration=train_epochs,
    optimizers=optimizer,
    schedulers=lr_scheduler,
    callbacks=[threshold_stopper],  # Instruct the trainer to use our threshold stopper callback
    train_subset_num_batches=10,    # Only training on a subset of the data to trigger the callback sooner
)
# Train!
trainer.fit()
What next?#
You've now seen how to implement early stopping in Composer using our EarlyStopper
and ThresholdStopper
callbacks.
To dig deeper into Composer callbacks, check out the docs and our API references.
In addition, please continue to explore our tutorials! Here are a couple of suggestions:
Continue learning about other Composer features like automatic gradient accumulation and automatic restarting from checkpoints
Give your model life after training with Composerโs export for inference tools
Come get involved with MosaicML!#
We'd love for you to get involved with the MosaicML community in any of these ways:
Star Composer on GitHub#
Help make others aware of our work by starring Composer on GitHub.
Join the MosaicML Slack#
Head on over to the MosaicML Slack to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!
Contribute to Composer#
Is there a bug you noticed or a feature you'd like? File an issue or make a pull request!