🍰 Fused LayerNorm#
Natural Language Processing

Fused LayerNorm replaces instances of torch.nn.LayerNorm with apex.normalization.fused_layer_norm. The fused kernel provides increased GPU utilization.
Figure: A visualization of the impact of Fused LayerNorm.
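For context, apex.normalization.FusedLayerNorm exposes the same constructor arguments as torch.nn.LayerNorm, so it can act as a drop-in replacement. Below is a minimal sketch, assuming NVIDIA Apex is installed and a CUDA device is available; the hidden size of 768 is an arbitrary example:

```python
import torch
from apex.normalization import FusedLayerNorm

hidden_size = 768  # arbitrary example dimension

ln = torch.nn.LayerNorm(hidden_size)            # standard PyTorch LayerNorm
fused_ln = FusedLayerNorm(hidden_size).cuda()   # Apex fused version, same arguments

x = torch.randn(8, 128, hidden_size, device="cuda")
y = fused_ln(x)  # same normalization semantics, computed with a fused CUDA kernel
```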
How to Use#
```python
# Apply surgery on the model to swap-in the Fused LayerNorm using the Composer functional API

import torch
import torch.nn.functional as F

import composer.functional as cf

def training_loop(model, train_loader):
    cf.apply_fused_layernorm(model)

    opt = torch.optim.Adam(model.parameters())
    loss_fn = F.cross_entropy
    model.train()

    for X, y in train_loader:
        y_hat = model(X)
        loss = loss_fn(y_hat, y)
        loss.backward()
        opt.step()
        opt.zero_grad()
```
```python
# Instantiate the algorithm and pass it to the Composer Trainer

from composer.trainer import Trainer
from composer.algorithms import FusedLayerNorm

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='1ep',
    algorithms=[FusedLayerNorm()],
)

trainer.fit()
```
Fused LayerNorm is implemented by performing model surgery, which looks for instances of torch.nn.LayerNorm and replaces them with apex.normalization.fused_layer_norm. This should be applicable to any model that utilizes a torch.nn.LayerNorm.
Fused LayerNorm does not have any hyperparameters. It utilizes the existing normalized_shape and eps from the original model.
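For intuition, the sketch below shows what such a surgery pass might look like; it is illustrative only (the actual Composer implementation is more general) and assumes NVIDIA Apex is installed. It highlights how each replacement can reuse the original layer's normalized_shape, eps, and affine parameters:

```python
import torch
from apex.normalization import FusedLayerNorm

def replace_layernorms(module: torch.nn.Module) -> None:
    """Recursively swap torch.nn.LayerNorm submodules for Apex FusedLayerNorm."""
    for name, child in module.named_children():
        if isinstance(child, torch.nn.LayerNorm):
            fused = FusedLayerNorm(
                normalized_shape=child.normalized_shape,
                eps=child.eps,
                elementwise_affine=child.elementwise_affine,
            )
            if child.elementwise_affine:
                # Reuse the existing affine parameters so training state is preserved.
                fused.weight = child.weight
                fused.bias = child.bias
            setattr(module, name, fused)
        else:
            replace_layernorms(child)  # recurse into nested submodules
```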
APEX's FusedLayerNorm achieves a substantial speedup over PyTorch by doing a few things:

- Instead of a naive implementation, which requires two passes over the input in order to estimate the variances, it uses Welford's Online Algorithm to estimate the variances in a single pass, producing a substantial wall-clock speedup (see the sketch after this list).
- Instead of requiring multiple CUDA kernel launches, it computes everything in a single kernel launch, thereby improving GPU utilization.
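For reference, here is Welford's online update for mean and variance in plain Python; it is purely illustrative and not a transcription of APEX's CUDA kernel:

```python
def welford_mean_var(xs):
    """Single-pass (online) mean and population variance via Welford's algorithm.

    A naive two-pass approach first computes the mean and then re-reads the data
    to accumulate squared deviations; Welford folds both into one pass.
    """
    count, mean, m2 = 0, 0.0, 0.0
    for x in xs:
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)  # uses the freshly updated mean
    variance = m2 / count if count else float("nan")
    return mean, variance

# Matches the two-pass result for this data: mean 5.0, population variance 4.0.
print(welford_mean_var([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]))
```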
✅ Fused LayerNorm Improves Training Speed
In our experiments, Fused LayerNorm improves the attainable tradeoffs between training speed and the final quality of the trained model. We recommend using Fused LayerNorm.
The Composer implementation of this method and the accompanying documentation were produced by Moin Nadeem at MosaicML.