# Source code for composer.devices.device_mps

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""The Apple M-series device used for training."""

from __future__ import annotations

from typing import Any, Dict, TypeVar

import torch
import torch.cuda.amp
import torch.utils.data
from packaging import version

from composer.devices.device import Device

# Public API of this module: only the MPS device class is exported.
__all__ = ['DeviceMPS']

# Type variable bound to ``torch.nn.Module``; used so ``DeviceMPS.module_to_device``
# is typed as returning the same module type it was given.
T_nnModule = TypeVar('T_nnModule', bound=torch.nn.Module)


class DeviceMPS(Device):
    """Device to support MPS, for training on Apple's M-series chips.

    This class takes no arguments.

    Raises:
        RuntimeError: On construction, if the installed torch is older than 1.12, if
            torch was not built with MPS support, or if MPS is not available at runtime
            (e.g. macOS older than 12.3).
    """

    # Empty string: this device names no ``torch.distributed`` backend.
    dist_backend = ''
    name = 'mps'

    def __init__(self) -> None:
        if version.parse(torch.__version__) < version.parse('1.12.0'):
            raise RuntimeError('Support for MPS device requires torch >= 1.12.')
        # Check build support BEFORE runtime availability: a torch build without MPS
        # also reports ``is_available() == False``, so checking availability first
        # would always blame the OS version and leave the "not built" branch unreachable.
        if not torch.backends.mps.is_built():  # type: ignore (version guarded)
            raise RuntimeError('torch was not built with MPS support.')
        if not torch.backends.mps.is_available():  # type: ignore (version guarded)
            raise RuntimeError('MPS requires MAC OSX >= 12.3')
        self._device = torch.device('mps')

    def module_to_device(self, module: T_nnModule) -> T_nnModule:
        """Move ``module`` to the MPS device.

        Args:
            module: The module to move.

        Returns:
            The result of ``module.to(<mps device>)``.
        """
        return module.to(self._device)

    def tensor_to_device(self, tensor: torch.Tensor) -> torch.Tensor:
        """Move ``tensor`` to the MPS device.

        Args:
            tensor: The tensor to move.

        Returns:
            The result of ``tensor.to(<mps device>)``.
        """
        return tensor.to(self._device)

    def state_dict(self) -> Dict[str, Any]:
        """Return the device state, which is always empty for MPS."""
        return {}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """Load device state; MPS has none, so only an empty ``state`` is accepted.

        Args:
            state: The state to load; must be empty.

        Raises:
            ValueError: If ``state`` is non-empty.
        """
        if state:
            raise ValueError('MPS device has no state.')