Fused LayerNorm#
[How to Use] - [Suggested Hyperparameters] - [Technical Details] - [Attribution]
Natural Language Processing
Fused LayerNorm replaces instances of torch.nn.LayerNorm with apex.normalization.fused_layer_norm. The fused kernel provides increased GPU utilization.
A visualization of the impact of Fused LayerNorm.
How to Use#
Functional Interface#
# Apply surgery on the model to swap in the Fused LayerNorm using the Composer functional API
import torch
import torch.nn.functional as F

import composer.functional as cf

def training_loop(model, train_loader):
    # Replace every torch.nn.LayerNorm in the model before creating the optimizer
    cf.apply_fused_layernorm(model)
    opt = torch.optim.Adam(model.parameters())
    loss_fn = F.cross_entropy
    model.train()
    for X, y in train_loader:
        y_hat = model(X)
        loss = loss_fn(y_hat, y)
        loss.backward()
        opt.step()
        opt.zero_grad()
Composer Trainer#
from composer.trainer import Trainer
from composer.algorithms import FusedLayerNorm

# model, train_dataloader, and eval_dataloader are assumed to be defined elsewhere
trainer = Trainer(model=model,
                  train_dataloader=train_dataloader,
                  eval_dataloader=eval_dataloader,
                  max_duration='1ep',
                  algorithms=[FusedLayerNorm()])
trainer.fit()
Implementation Details#
Fused LayerNorm is implemented by performing model surgery, which looks for instances of torch.nn.LayerNorm and replaces them with apex.normalization.fused_layer_norm. This should be applicable to any model that uses torch.nn.LayerNorm.
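The surgery step described above can be sketched roughly as follows. This is a minimal illustration and not Composer's actual implementation: swap_layernorm, make_replacement, and StandInLayerNorm are hypothetical names introduced here, and the stand-in class takes the place of the Apex layer so the example runs without Apex or a GPU. It also assumes the replacement layer uses the same parameter names (weight and bias) as torch.nn.LayerNorm.

```python
import torch


class StandInLayerNorm(torch.nn.LayerNorm):
    """Hypothetical placeholder for the fused layer so the sketch runs without Apex."""


def swap_layernorm(module, make_replacement):
    """Replace every torch.nn.LayerNorm under `module`; return how many were swapped."""
    swapped = 0
    # Copy the child list first so reassigning attributes cannot disturb iteration.
    for name, child in list(module.named_children()):
        if isinstance(child, torch.nn.LayerNorm):
            # Reuse the configuration of the original layer.
            new_layer = make_replacement(child.normalized_shape, child.eps)
            # Carry over the learned scale and shift (assumes matching parameter names).
            new_layer.load_state_dict(child.state_dict())
            setattr(module, name, new_layer)
            swapped += 1
        else:
            # Recurse into containers and other submodules.
            swapped += swap_layernorm(child, make_replacement)
    return swapped


# A small model with one top-level and one nested LayerNorm.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.LayerNorm(8),
    torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.LayerNorm(4)),
)
num_swapped = swap_layernorm(model, lambda shape, eps: StandInLayerNorm(shape, eps=eps))
```

With Apex installed, a factory that builds the Apex fused layer could be passed in place of the stand-in. In practice, cf.apply_fused_layernorm already performs this replacement, so a hand-written walk like this is only useful for understanding what the surgery does.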
Suggested Hyperparameters#
Fused LayerNorm does not have any hyperparameters. It reuses the existing normalized_shape and eps values from the torch.nn.LayerNorm layers in the original model.
Technical Details#
APEX's FusedLayerNorm achieves a substantial speedup over PyTorch by doing a few things:
Instead of a naive implementation, which requires two passes over the input in order to estimate variances, it uses Welford's Online Algorithm to estimate the variances in a single pass, creating a substantive wall-clock speedup.
Instead of requiring multiple CUDA kernel launches, it computes everything in a single kernel launch, which improves GPU utilization.
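To make the single-pass idea concrete, here is a minimal pure-Python sketch of Welford's online algorithm applied to one feature vector. It only illustrates the update rule: the real kernel is written in CUDA and parallelized, and the function names welford_mean_var and layer_norm are hypothetical. The sketch assumes a non-empty input and omits the learned scale and shift.

```python
import math


def welford_mean_var(values):
    """Return (mean, population variance) of `values` using one pass over the data."""
    count = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for x in values:
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)  # combines the old and the updated mean
    return mean, m2 / count


def layer_norm(values, eps=1e-5):
    """Normalize one feature vector, without the learned scale and shift."""
    mean, var = welford_mean_var(values)
    return [(x - mean) / math.sqrt(var + eps) for x in values]


mean, var = welford_mean_var([1.0, 2.0, 3.0, 4.0])
normalized = layer_norm([1.0, 2.0, 3.0, 4.0])
```

A two-pass version would first loop over the input to compute the mean and then loop again to accumulate squared deviations. Welford's update keeps both running quantities current inside a single loop, which is what lets a kernel read the input only once.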
Attribution#
The Composer implementation of this method and the accompanying documentation were produced by Moin Nadeem at MosaicML.