Source code for composer.core.precision

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Enum class for the numerical precision to be used by the model."""

import contextlib
import textwrap
from typing import Any, Generator, Optional, Union

import torch

from composer.devices import Device, DeviceCPU
from composer.utils import StringEnum, is_xla_installed

try:
    import transformer_engine.pytorch as te
    te_installed = True
except ImportError:
    te_installed = False

__all__ = ['Precision', 'get_precision_context', '_validate_precision']


class Precision(StringEnum):
    """Enum class for the numerical precision to be used by the model.

    Attributes:
        FP32: Use 32-bit floating-point precision. Compatible with CPUs and GPUs.
        AMP_FP16: Use :mod:`torch.cuda.amp` with 16-bit floating-point precision. Only
            compatible with GPUs.
        AMP_BF16: Use :mod:`torch.cuda.amp` with 16-bit BFloat precision.
        AMP_FP8: Use :mod:`transformer_engine.pytorch.fp8_autocast` with 8-bit FP8 precision.
    """

    FP32 = 'fp32'
    AMP_FP16 = 'amp_fp16'
    AMP_BF16 = 'amp_bf16'
    AMP_FP8 = 'amp_fp8'
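
# Usage sketch (illustrative, not part of the upstream module): `Precision` is a
# string-backed enum, so configs can pass either the enum member or its string
# value and both resolve to the same member via standard `enum.Enum` value lookup.
# The helper name below is hypothetical.
def _example_precision_from_string() -> None:
    assert Precision('amp_bf16') is Precision.AMP_BF16
    assert Precision('fp32') is Precision.FP32
    assert Precision.AMP_FP16.value == 'amp_fp16'
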
def _validate_precision(precision: Precision, device: Device):
    """Validate the precision and device combination."""
    if isinstance(device, DeviceCPU) and precision != Precision.FP32:
        raise ValueError(f'{precision} is not supported for CPU training.')
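
# Usage sketch (illustrative, not part of the upstream module): the validator
# rejects any mixed-precision setting on CPU devices, so only fp32 passes.
# The helper name below is hypothetical.
def _example_validate_cpu_precision() -> None:
    _validate_precision(Precision.FP32, DeviceCPU())  # OK: fp32 on CPU
    try:
        _validate_precision(Precision.AMP_BF16, DeviceCPU())
    except ValueError:
        pass  # expected: amp_bf16 is not supported for CPU training
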
@contextlib.contextmanager
def get_precision_context(
    precision: Union[str, Precision],
    precision_config: Optional[dict[str, Any]] = None,
    fp8_autocast_enabled: bool = True,
) -> Generator[None, None, None]:
    """Returns a context manager to automatically cast to a specific precision.

    Args:
        precision (str | Precision): Precision for the context.
        precision_config (Optional[dict[str, Any]]): Config for FP8 scaling strategy. See parameters for
            `DelayedScaling <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html?highlight=delayedscaling#transformer_engine.common.recipe.DelayedScaling>`_.
        fp8_autocast_enabled (bool): Whether to enable FP8 autocast. Defaults to True.
    """
    precision = Precision(precision)
    if precision == Precision.FP32:
        if torch.cuda.is_available():
            with torch.cuda.amp.autocast(False):
                yield
        else:
            # Yield here to avoid warnings about cuda not being available
            yield
    elif precision == Precision.AMP_FP16:
        # Retain compatibility with PyTorch < 1.10
        if torch.cuda.is_available():
            with torch.cuda.amp.autocast(True):
                yield
        elif is_xla_installed():
            with torch.autocast('xla', dtype=torch.float16):
                yield
        else:
            yield
    elif precision == Precision.AMP_BF16:
        if torch.cuda.is_available():
            with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
                yield
        elif is_xla_installed():
            with torch.autocast('xla', dtype=torch.bfloat16):
                yield
        else:
            yield
    elif precision == Precision.AMP_FP8:
        if te_installed and torch.cuda.get_device_capability() >= (8, 9):
            from transformer_engine.common.recipe import DelayedScaling, Format
            if precision_config is None:
                precision_config = {
                    'fp8_format': Format.HYBRID,
                    'amax_history_len': 16,
                    'amax_compute_algo': 'max',
                }
            fp8_recipe = DelayedScaling(**precision_config)
            with te.fp8_autocast(enabled=fp8_autocast_enabled, fp8_recipe=fp8_recipe):
                # The te.onnx_export flag ensures that we save all fp8 buffers
                # as tensors instead of bytes. This is necessary for proper
                # saving and resumption of checkpoints.
                with te.onnx_export(enabled=True):
                    yield
        else:
            if te_installed:
                raise RuntimeError('AMP_FP8 precision is used but current device does not support it.')
            else:
                raise ImportError(
                    textwrap.dedent(
                        """\
                        AMP_FP8 precision is used but TransformerEngine is not installed.
                        After making sure torch is already installed, please install it using
                        pip install --upgrade git+https://github.com/NVIDIA/TransformerEngine.git@stable""",
                    ),
                )
    else:
        raise ValueError(f'Unsupported precision: {precision}')
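
# Usage sketch (illustrative, not part of the upstream module): wrap the forward
# and loss computation in `get_precision_context` so autocast applies only where
# it is wanted; the backward pass is typically run outside the context.
# `model`, `batch`, and `loss_fn` are hypothetical placeholders for whatever the
# training loop provides, and the helper name below is hypothetical.
def _example_amp_bf16_forward(model: torch.nn.Module, batch: torch.Tensor, loss_fn) -> torch.Tensor:
    with get_precision_context('amp_bf16'):
        outputs = model(batch)
        loss = loss_fn(outputs)
    return loss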