# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Core ALiBi classes and functions."""

from __future__ import annotations

import logging
from typing import Optional, Sequence, Union

import torch
from torch.optim import Optimizer

from composer.core import Algorithm, Event, State
from composer.loggers import Logger
from composer.utils import MissingConditionalImportError, module_surgery

log = logging.getLogger(__name__)

__all__ = ['Alibi', 'apply_alibi']


def apply_alibi(
model: torch.nn.Module,
max_sequence_length: int,
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
) -> None:
"""Removes position embeddings and replaces the attention function and attention mask
as per :class:`.Alibi`. Note that the majority of the training speed-up from using ALiBi
comes from being able to train on shorter sequence lengths; this function does not scale
the training sequence length as :class:`.Alibi` does, so little speedup will be
observed from using it alone. See the :doc:`Method Card </method_cards/alibi>` for
more details. This function should be called after the model is instantiated and
before training begins.
Example:
.. code-block:: python
import composer.functional as cf
cf.apply_alibi(
model=model,
max_sequence_length=512
)
Args:
model (torch.nn.Module): Model to transform.
max_sequence_length (int): Maximum sequence length that the
model will be able to accept. Internally, the transformations applied by alibi
change sequence-shaped tensors to handle sequences up to ``max_sequence_length``.
Depending on ``max_sequence_length`` and ``model`` these changes could increase
or decrease the model's maximum sequence length.
At minimum, ``max_sequence_length`` should be set to the sequence length used
during training. However, if evaluating on sequence lengths longer than those
used in training, ``max_sequence_length`` should be set accordingly.
Note that larger ``max_sequence_length`` means a larger memory footprint of
the model. So, it is best to set this parameter equal the longest
sequence length that will be seen during training and/or evaluation.
optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional):
Existing optimizers bound to ``model.parameters()``. All optimizers that have already been
constructed with ``model.parameters()`` must be specified here so
they will optimize the correct parameters.
If the optimizer(s) are constructed *after* calling this function,
then it is safe to omit this parameter. These optimizers will see the correct
model parameters.
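
    Example with an existing optimizer (a minimal sketch; it assumes ``model`` is a supported
    transformer model and that the optimizer below was built from ``model.parameters()``
    before this call; the choice of ``AdamW`` and the learning rate are illustrative only):

    .. code-block:: python

        import torch

        import composer.functional as cf

        optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-4)

        cf.apply_alibi(
            model=model,
            max_sequence_length=512,
            optimizers=optimizer,
        )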
"""
try:
from composer.algorithms.alibi.attention_surgery_functions import policy_registry
except ImportError as e:
        raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers') from e

    # To use model surgery utilities, we need to define a policy of type
# Mapping[Type[torch.nn.Module], ReplacementFunction], where ReplacementFunction is
# Callable[[torch.nn.Module, Optional[int]], Optional[torch.nn.Module]].
#
# This mapping is built by the source code in `./attention_surgery_functions/` but
# needs to be completed here by "freezing" alibi-specific arguments.
#
# For additional details, see `./attention_surgery_functions/utils.py`.
def as_replacement_function(surgery_function):
def replacement_function(module: torch.nn.Module, module_index: int):
return surgery_function(module, module_index, max_sequence_length=max_sequence_length)
return replacement_function
# Wrap each alibi_surgery_function as a ReplacementFunction by "freezing" `max_sequence_length`
policies = {
module_class: as_replacement_function(alibi_surgery_function)
for module_class, alibi_surgery_function in policy_registry.items()
}
# Note: `policies` defines replacements for _all_ the modules registered in `policy_registry`,
# meaning that some replacements may be irrelevant for `model`.
# Conversely, attention modules within `model` may be ignored if they are not registered by the
# implementations within `./attention_surgery_functions/`.
replaced_pairs = module_surgery.replace_module_classes(model, optimizers=optimizers, policies=policies)
count = len(replaced_pairs)
if count == 0:
supported_modules = ''.join(sorted(['\n\t' + c.__module__ + '.' + c.__name__ for c in policy_registry.keys()]))
log.warning(
f'ALiBi had no effect on the model! Support for ALiBi surgery '
f'is currently limited to the following classes: {supported_modules}',
)
else:
        log.info(f' {count} instances of ALiBi added')


class Alibi(Algorithm):
"""ALiBi (Attention with Linear Biases; `Press et al, 2021 <https://arxiv.org/abs/2108.12409>`_) dispenses with
position embeddings and instead directly biases attention matrices such that nearby tokens attend to one another
more strongly.
ALiBi yields excellent extrapolation to unseen sequence lengths
compared to other position embedding schemes. We leverage this
extrapolation capability by training with shorter sequence lengths,
which reduces the memory and computation load.
This algorithm runs on :attr:`.Event.INIT` to modify the model
before the model has been moved to accelerators. It also runs on
:attr:`.Event.AFTER_DATALOADER` to modify the shape of a batch of
data after the model and data have been moved to accelerators.
See the :doc:`Method Card </method_cards/alibi>` for more details.
Example:
.. code-block::
from composer.algorithms import Alibi
from composer.trainer import Trainer
alibi = Alibi(
max_sequence_length=512,
train_sequence_length_scaling=0.25,
)
trainer = Trainer(
model=model,
train_dataloader=train_dataloader,
max_duration="1ep",
algorithms=[alibi]
)
Args:
        max_sequence_length (int): Maximum sequence length that the
            model will be able to accept. This is sometimes necessary for evaluating
            on sequence lengths longer than the model was initialized to
            accommodate.
        train_sequence_length_scaling (float, optional): Amount by which to scale
            training sequence length. Each batch of training data is
            reshaped from shape :math:`(batch, sequence\\_length)` to
            :math:`(\\frac{batch}{train\\_sequence\\_length\\_scaling},
            sequence\\_length \\times train\\_sequence\\_length\\_scaling)`,
            leaving the total number of tokens unchanged; see the example below.
            Default: ``0.25``.
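
    For example (a rough sketch of the reshape applied to each tensor in the batch;
    the shapes below are illustrative), with the default
    ``train_sequence_length_scaling=0.25`` a tensor of shape ``(8, 1024)`` is reshaped
    to ``(32, 256)``:

    .. code-block:: python

        import torch

        scaling = 0.25
        input_ids = torch.zeros(8, 1024, dtype=torch.long)  # (batch, sequence_length)
        batch_len, sequence_len = input_ids.shape[0], input_ids.shape[1]
        reshaped = input_ids.reshape(int(batch_len / scaling), int(sequence_len * scaling))
        assert reshaped.shape == (32, 256)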
"""
def __init__(self, max_sequence_length: int, train_sequence_length_scaling: float = 0.25) -> None:
self.max_sequence_length = max_sequence_length
self.train_sequence_length_scaling = train_sequence_length_scaling
self._applied = False
def __repr__(self) -> str:
return f'{self.__class__.__name__}(max_sequence_length={self.max_sequence_length},train_sequence_length_scaling={self.train_sequence_length_scaling})'
@staticmethod
def required_on_load() -> bool:
return True
def match(self, event: Event, state: State) -> bool:
return (event == Event.INIT and not self._applied) or event == Event.AFTER_DATALOADER
def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
if event == Event.INIT:
apply_alibi(
state.model,
optimizers=state.optimizers,
max_sequence_length=self.max_sequence_length,
)
self._applied = True
elif event == Event.AFTER_DATALOADER:
# Change sequence length by reshaping data
if not self.train_sequence_length_scaling == 1 and \
hasattr(state, 'batch') and isinstance(state.batch, dict):
sequence_scaling = self.train_sequence_length_scaling
for k, v in state.batch.items():
batch_len, sequence_len = v.shape[0], v.shape[1]
state.batch[k] = v.reshape(int(batch_len / sequence_scaling), int(sequence_len * sequence_scaling))