# Source code for composer.trainer.devices.device_cpu
# Copyright 2021 MosaicML. All Rights Reserved.
"""The CPU device used for training."""
from __future__ import annotations
import logging
from contextlib import contextmanager
from typing import Any, Dict, Generator, TypeVar, Union
import torch
from composer.core import Precision
from composer.trainer.devices.device import Device, T_nnModule
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Names exported by ``from ... import *``.
__all__ = ["DeviceCPU"]
# Type variable bound to ``torch.nn.Module``; used by ``module_to_device`` so
# the annotated return type matches the type of the module passed in.
# NOTE(review): ``T_nnModule`` is also imported from
# ``composer.trainer.devices.device`` above, and this assignment rebinds it --
# confirm whether both the import and this local definition are needed.
T_nnModule = TypeVar("T_nnModule", bound=torch.nn.Module)
class DeviceCPU(Device):
    """An extension of :class:`~composer.trainer.devices.device.Device` for CPUs.

    This class takes no arguments.
    """

    # Distributed backend identifier associated with this device.
    # NOTE(review): presumably consumed by the base class / trainer when
    # initializing torch.distributed -- confirm against the caller.
    dist_backend = "gloo"

    # The torch device that modules and tensors are moved to.
    _device = torch.device('cpu')

    def module_to_device(self, module: T_nnModule) -> T_nnModule:
        """Move a module to the CPU.

        Args:
            module: The module to move.

        Returns:
            The result of ``module.to(...)`` with the CPU device.
        """
        return module.to(self._device)

    def tensor_to_device(self, tensor: torch.Tensor) -> torch.Tensor:
        """Move a tensor to the CPU.

        Args:
            tensor: The tensor to move.

        Returns:
            The result of ``tensor.to(...)`` with the CPU device.
        """
        return tensor.to(self._device)

    @contextmanager
    def precision_context(self, precision: Union[str, Precision]) -> Generator[None, None, None]:
        """Context manager for running code under the given precision.

        Only ``Precision.FP32`` is accepted; for it the context does nothing
        besides yielding control to the ``with`` body.

        Args:
            precision: The requested precision, as a :class:`Precision` or a
                value accepted by the :class:`Precision` constructor.

        Raises:
            ValueError: When the context is entered with any precision other
                than ``Precision.FP32``.
        """
        precision = Precision(precision)
        # Guard clause: reject anything but FP32 before yielding.
        if precision != Precision.FP32:
            raise ValueError(f"Precision {precision} not supported for a CPU")
        yield

    def state_dict(self) -> Dict[str, Any]:
        """Return the device state, which is always an empty dict for a CPU."""
        # CPU device has no RNG state
        return {}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """Validate a state previously produced by :meth:`state_dict`.

        Nothing is restored, since the CPU device keeps no state.

        Args:
            state: The state to load; expected to be empty.

        Raises:
            ValueError: If ``state`` is not empty.
        """
        if state:
            raise ValueError("CPU device has no state.")