Source code for composer.core.precision

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Enum class for the numerical precision to be used by the model."""

import contextlib
import textwrap
from typing import Any, Dict, Generator, Optional, Union

import torch

from composer.utils import StringEnum, is_xla_installed

    import transformer_engine.pytorch as te
    te_installed = True
except ImportError:
    te_installed = False

__all__ = ['Precision', 'get_precision_context']

class Precision(StringEnum):
    """Numerical precision modes that a model can be run with.

    Attributes:
        FP32: Plain 32-bit floating-point precision. Compatible with CPUs and GPUs.
        AMP_FP16: Automatic mixed precision through :mod:`torch.cuda.amp` with 16-bit
            floating-point precision. Only compatible with GPUs.
        AMP_BF16: Automatic mixed precision through :mod:`torch.cuda.amp` with 16-bit
            BFloat precision.
        AMP_FP8: 8-bit FP8 precision through :mod:`transformer_engine.pytorch.fp8_autocast`.
    """

    FP32 = 'fp32'
    AMP_FP16 = 'amp_fp16'
    AMP_BF16 = 'amp_bf16'
    AMP_FP8 = 'amp_fp8'
@contextlib.contextmanager
def get_precision_context(precision: Union[str, Precision],
                          precision_config: Optional[Dict[str, Any]] = None) -> Generator[None, None, None]:
    """Returns a context manager to automatically cast to a specific precision.

    Args:
        precision (str | Precision): Precision for the context
        precision_config (Optional[Dict[str, Any]]): Config for FP8 scaling strategy. Only read
            when ``precision`` is ``AMP_FP8``, where it is unpacked as keyword arguments into
            ``DelayedScaling``. When ``None``, a default of ``fp8_format=Format.HYBRID``,
            ``amax_history_len=16`` and ``amax_compute_algo='max'`` is used. See parameters for
            `DelayedScaling <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#transformer_engine.common.recipe.DelayedScaling>`_.

    Raises:
        ImportError: If ``AMP_FP8`` is requested but TransformerEngine could not be imported.
        RuntimeError: If ``AMP_FP8`` is requested, TransformerEngine is installed, but no CUDA
            device with compute capability of at least 8.9 is available.
        ValueError: If ``precision`` is not one of the precisions handled here.
    """
    precision = Precision(precision)

    if precision == Precision.FP32:
        if torch.cuda.is_available():
            # Explicitly disable autocast so an enclosing autocast region does not leak in.
            with torch.cuda.amp.autocast(False):
                yield
        else:
            # Yield here to avoid warnings about cuda not being available
            yield
    elif precision == Precision.AMP_FP16:
        # Retain compatibility with PyTorch < 1.10
        if torch.cuda.is_available():
            with torch.cuda.amp.autocast(True):
                yield
        elif is_xla_installed():
            with torch.autocast('xla', dtype=torch.float16):
                yield
        else:
            yield
    elif precision == Precision.AMP_BF16:
        if torch.cuda.is_available():
            with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
                yield
        elif is_xla_installed():
            with torch.autocast('xla', dtype=torch.bfloat16):
                yield
        else:
            yield
    elif precision == Precision.AMP_FP8:
        if not te_installed:
            raise ImportError(
                textwrap.dedent("""\
                    AMP_FP8 precision is used but TransformerEngine is not installed.
                    After making sure torch is already installed, please install it using
                    pip install --upgrade git+https://github.com/NVIDIA/TransformerEngine.git@stable"""))

        # Check CUDA availability first: querying the device capability on a host without
        # CUDA raises an unrelated torch error instead of the intended message below.
        if not (torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)):
            raise RuntimeError('AMP_FP8 precision is used but current device does not support it.')

        # Imported lazily because TransformerEngine is an optional dependency.
        from transformer_engine.common.recipe import DelayedScaling, Format

        if precision_config is None:
            precision_config = {
                'fp8_format': Format.HYBRID,
                'amax_history_len': 16,
                'amax_compute_algo': 'max',
            }
        fp8_recipe = DelayedScaling(**precision_config)
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            # The te.onnx_export flag ensures that we save all fp8 buffers
            # as tensors instead of bytes. This is necessary for proper
            # saving and resumption of checkpoints.
            with te.onnx_export(enabled=True):
                yield
    else:
        raise ValueError(f'Unsupported precision: {precision}')