# Source code for composer.trainer.devices.device_gpu
# Copyright 2021 MosaicML. All Rights Reserved.
"""The GPU device used for training."""
from __future__ import annotations
from contextlib import contextmanager
from typing import Any, Dict, Generator, TypeVar, Union
import torch
import torch.cuda.amp
import torch.utils.data
from packaging import version
from composer.core.precision import Precision
from composer.trainer.devices.device import Device, T_nnModule
from composer.utils import dist
__all__ = ["DeviceGPU"]
# NOTE(review): this re-declaration shadows the T_nnModule imported on the
# `from composer.trainer.devices.device import ...` line above — one of the
# two should probably be removed; verify against the original module layout.
T_nnModule = TypeVar("T_nnModule", bound=torch.nn.Module)
class DeviceGPU(Device):
    """An extension of :class:`~composer.trainer.devices.device.Device` for GPUs.

    This class takes no arguments.
    """
    dist_backend = "nccl"

    def __init__(self):
        # Pin this process to the GPU matching its local rank so each worker
        # in a distributed job drives a distinct device.
        gpu = dist.get_local_rank()
        self._device = torch.device(f"cuda:{gpu}")
        torch.cuda.set_device(self._device)
        assert torch.cuda.current_device() == gpu

    def module_to_device(self, module: T_nnModule) -> T_nnModule:
        """Move ``module`` onto this process's GPU and return it."""
        return module.to(self._device)

    def tensor_to_device(self, tensor: torch.Tensor) -> torch.Tensor:
        """Copy ``tensor`` onto this process's GPU.

        ``non_blocking=True`` lets the copy overlap with compute when the
        source tensor is in pinned host memory.
        """
        return tensor.to(self._device, non_blocking=True)

    @contextmanager
    def precision_context(self, precision: Union[str, Precision]) -> Generator[None, None, None]:
        """Run the enclosed code under the requested numerical precision.

        Args:
            precision: The precision to use — a :class:`~composer.core.precision.Precision`
                member or its string value. ``AMP`` enables CUDA autocast; ``BF16``
                enables bfloat16 autocast; ``FP32`` (or any other value) runs with
                autocast disabled.

        Raises:
            ValueError: If ``BF16`` is requested on torch < 1.10, which lacks
                bfloat16 autocast support.
        """
        precision = Precision(precision)
        if precision == Precision.BF16:
            # The dtype argument to autocast (needed for bfloat16) was added
            # in torch 1.10.
            if version.parse(torch.__version__) < version.parse("1.10"):
                raise ValueError(f"BF16 precision requires torch >= 1.10, got version {torch.__version__}")
            with torch.cuda.amp.autocast(True, torch.bfloat16):  # type: ignore
                yield
        else:
            # AMP enables mixed-precision autocast; FP32 and any other value
            # disable it, matching plain full-precision execution.
            enabled = precision == Precision.AMP
            with torch.cuda.amp.autocast(enabled):  # type: ignore
                yield

    def state_dict(self) -> Dict[str, Any]:
        """Capture the CUDA RNG state so a run can be resumed deterministically."""
        return {
            "rng": torch.cuda.get_rng_state(),
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """Restore the CUDA RNG state captured by :meth:`state_dict`."""
        torch.cuda.set_rng_state(state["rng"])