Gradient Clipping#
[How to Use] - [Suggested Hyperparameters] - [Attribution] - [API Reference]
Computer Vision, Natural Language Processing
Gradient Clipping is a technique used to stabilize the training of neural networks. It was originally introduced to address exploding gradients when training recurrent neural networks, but it has also been shown to be useful for transformers and convolutional neural networks. Gradient clipping usually consists of clipping the extreme values of a model's gradients (or the gradients' norms) to be under a certain threshold. The gradient clipping operation is executed after gradients are computed (after loss.backward()), but before the weights of the network are updated (optim.step()).
How to Use#
The desired gradient clipping type can be controlled using the clipping_type argument.
The Different Flavors of Gradient Clipping#
Gradient clipping by value:#
Constrains all gradients to be between \([-\lambda, \lambda]\), where \(\lambda\) is the clipping_threshold.
import composer.functional as cf

cf.apply_gradient_clipping(model.parameters(),
                           clipping_type='value',
                           clipping_threshold=clipping_threshold)
Gradient clipping by norm:#
Multiplies all gradients by \(\min(1, \frac{\lambda}{||G||})\), where \(\lambda\) is the clipping_threshold and \(||G||\) is the total L2 norm of all gradients.
import composer.functional as cf

cf.apply_gradient_clipping(model.parameters(),
                           clipping_type='norm',
                           clipping_threshold=clipping_threshold)
Adaptive Gradient Clipping (AGC):#
Clips all gradients based on the ratio of the gradient norm to the parameter norm by multiplying them by \(\min(1, \lambda\frac{||W||}{||G||})\), where \(\lambda\) is the clipping_threshold, \(||G||\) is the norm of the gradients, and \(||W||\) is the norm of the weights.
import composer.functional as cf

cf.apply_gradient_clipping(model.parameters(),
                           clipping_type='adaptive',
                           clipping_threshold=clipping_threshold)
Functional Interface#
# Run gradient clipping directly on the model right after a loss.backward() call
# using the Composer functional API.

import torch
import torch.nn.functional as F

import composer.functional as cf

clipping_type = 'norm'  # can also be 'adaptive' or 'value'

def training_loop(model, train_loader, num_epochs=1):
    opt = torch.optim.Adam(model.parameters())
    loss_fn = F.cross_entropy
    model.train()

    for epoch in range(num_epochs):
        for X, y in train_loader:
            opt.zero_grad()
            y_hat = model(X)
            loss = loss_fn(y_hat, y)
            loss.backward()
            # Clip gradients after backward() and before the optimizer step
            cf.apply_gradient_clipping(model.parameters(), clipping_type=clipping_type,
                                       clipping_threshold=0.1)
            opt.step()
Composer Trainer#
# Instantiate the algorithm and pass it into the Trainer.
# The trainer will automatically run it at the appropriate points in the training loop.

from composer.algorithms import GradientClipping
from composer.trainer import Trainer

clipping_type = 'norm'  # can also be 'adaptive' or 'value'
gc = GradientClipping(clipping_type=clipping_type, clipping_threshold=0.1)

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='1ep',
    algorithms=[gc],
)

trainer.fit()
Implementation Details#
clipping_type='norm'#
Norm-based gradient clipping is implemented as follows:

On Event.AFTER_TRAIN_BATCH, for every parameter in the model that has gradients:

1. Compute the parameter's gradients and concatenate all parameters' gradients into one big vector
2. Compute the norm of all the gradients (a single scalar), \(||G||\)
3. Compute the clipping coefficient, clip_coeff: \(\lambda / ||G||\)
4. Clamp the clip_coeff to be less than or equal to 1.0
5. Multiply all the gradients by the clip_coeff
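These steps amount to the same computation as PyTorch's torch.nn.utils.clip_grad_norm_. As a rough illustration only (not Composer's exact implementation), a minimal sketch in plain PyTorch might look like the following; the small epsilon in the denominator is an assumption added for numerical safety:

import torch

def clip_grad_norm_sketch(parameters, clipping_threshold):
    # Gather all gradients that exist
    grads = [p.grad for p in parameters if p.grad is not None]
    # ||G||: total L2 norm over all gradients, treated as one concatenated vector
    total_norm = torch.norm(torch.stack([g.detach().norm(2) for g in grads]), 2)
    # clip_coeff = lambda / ||G||, clamped to at most 1.0
    clip_coeff = (clipping_threshold / (total_norm + 1e-6)).clamp(max=1.0)
    # Scale every gradient by the (possibly 1.0) coefficient
    for g in grads:
        g.detach().mul_(clip_coeff)
    return total_norm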
clipping_type='value'#
Value-based gradient clipping is implemented as follows:

On Event.AFTER_TRAIN_BATCH, for every parameter in the model that has gradients:

Any gradients that are greater than clipping_threshold are set to clipping_threshold, and any gradients less than -clipping_threshold are set to -clipping_threshold.
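For illustration only (not necessarily Composer's exact implementation), value clipping is essentially an element-wise clamp, as in this minimal sketch:

import torch

def clip_grad_value_sketch(parameters, clipping_threshold):
    # Constrain every gradient element to [-clipping_threshold, clipping_threshold]
    for p in parameters:
        if p.grad is not None:
            p.grad.detach().clamp_(min=-clipping_threshold, max=clipping_threshold)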
clipping_type='adaptive'#
Adaptive gradient clipping is implemented as follows:

On Event.AFTER_TRAIN_BATCH, for every parameter in the model that has gradients:

1. Compute the parameter's weight norm with an L2 norm (normalized across rows for MLPs, across entire filters for CNNs, and across the entire vector for biases), \(||W||\)
2. Compute the parameter's gradient norm with an L2 norm, \(||G||\)
3. If \(||G|| > \lambda||W||\), scale all the contributing gradients by \(\lambda \frac{||W||}{||G||}\)
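As a simplified illustration, the sketch below treats each parameter tensor as a single unit rather than computing the unit-wise (per-row / per-filter) norms described above, so it is not Composer's exact implementation. The epsilon floor on the weight norm follows the AGC paper's suggestion to avoid always clipping near-zero weights:

import torch

def agc_sketch(parameters, clipping_threshold, eps=1e-3):
    for p in parameters:
        if p.grad is None:
            continue
        # ||W||, floored at eps so near-zero weights do not force the gradient to zero
        w_norm = p.detach().norm(2).clamp(min=eps)
        # ||G||
        g_norm = p.grad.detach().norm(2)
        max_norm = clipping_threshold * w_norm
        # If ||G|| > lambda * ||W||, scale the gradients by lambda * ||W|| / ||G||
        if g_norm > max_norm:
            p.grad.detach().mul_(max_norm / (g_norm + 1e-6))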
Suggested Hyperparameters#
Norm-based gradient clipping#
The original authors of this type of clipping, Pascanu et al., used gradient clipping with recurrent neural networks. They recommend monitoring the average norm of your model's gradients over many iterations as a heuristic to help choose a value for the clipping_threshold.
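One simple way to gather this heuristic (a hypothetical helper, not part of Composer) is to log the total gradient norm after each backward pass and average it over many iterations:

import torch

def total_grad_norm(model):
    # Call after loss.backward(); track this value over many iterations
    # to get a sense of a reasonable clipping_threshold
    norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
    return torch.norm(torch.stack(norms), 2).item()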
For computer vision, the authors of the famous Inception convolutional neural network architecture used a clipping_threshold of 2.0, which they claim helped stabilize their training.
For NLP with transformers, Keskar et al. used a clipping_threshold of 0.25 for CTRL, their conditional transformer language model.
The authors of TaBERT, a transformer-based BERT model for tabular data, recommend a clipping_threshold of 1.0. The authors of the Compressive Transformer and Gated Convolutional Neural Networks both used a clipping_threshold of 0.1.
Value-based gradient clipping#
The original author of this type of clipping, Mikolov, used it for training recurrent neural networks and recommends setting the clipping_threshold to 15. This approach to gradient clipping is not as prevalent as norm-based clipping, so to our knowledge there are not many examples of good settings for clipping_threshold.
Adaptive gradient clipping#
We haven't done much experimentation with AGC. However, the original authors, Brock et al., and Ayush Thakur have done some ablations and have some recommendations. Note that both parties use AGC with NF-ResNets, a variation of ResNets that removes Batch Norm and adds Weight Standardization, among other modifications.
Brock et al. recommend using a clipping_threshold of 0.01 for batch sizes between 1024 and 4096. For smaller batch sizes, AGC's effects are less pronounced, so they recommend a larger (less strict) clipping_threshold, with performance slightly increasing up to 0.08. They also recommend removing AGC from the last linear layer of the network; a sketch of one way to do this with the functional API is shown below.
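One hypothetical way to follow the last-layer recommendation with the functional API shown above is to filter the final linear layer's parameters out of the iterable passed to apply_gradient_clipping. The attribute name model.fc is an assumption about the model's structure:

import composer.functional as cf

# Hypothetical: skip AGC for the last linear layer, assumed here to be `model.fc`
last_layer_param_ids = {id(p) for p in model.fc.parameters()}
clipped_params = (p for p in model.parameters() if id(p) not in last_layer_param_ids)
cf.apply_gradient_clipping(clipped_params,
                           clipping_type='adaptive',
                           clipping_threshold=0.01)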
Thakur recommends a larger clipping_threshold for small batch sizes (at least 0.16 for batch sizes of 128 and 256) and a smaller clipping_threshold for large batch sizes. They also found that AGC seems to work especially well for the NF-ResNet architecture: specifically, with a clipping_threshold of 0.01 and a batch size of 1024, AGC does not improve the performance of a vanilla ResNet with Batch Norm removed.
Attribution#
High-Performance Large-Scale Image Recognition Without Normalization by Andrew Brock, Soham De, Samuel L. Smith, Karen Simonyan. Published in ICML 2021.
On the difficulty of training recurrent neural networks by R. Pascanu, T. Mikolov, and Y. Bengio, 2012
Statistical language models based on neural networks by T. Mikolov
The Composer implementation of this method and the accompanying documentation were produced by Evan Racah at MosaicML.
API Reference#
Algorithm class: composer.algorithms.GradientClipping
Functional: composer.functional.apply_gradient_clipping()