# Source code for composer.trainer.devices.device_hparams
# Copyright 2021 MosaicML. All Rights Reserved.
"""The :class:`~yahp.hparams.Hparams` used to construct devices."""
from __future__ import annotations
from abc import abstractmethod
from dataclasses import dataclass
import yahp as hp
from composer.trainer.devices.device import Device
from composer.trainer.devices.device_cpu import DeviceCPU
from composer.trainer.devices.device_gpu import DeviceGPU
__all__ = ["DeviceHparams", "CPUDeviceHparams", "GPUDeviceHparams"]
@dataclass
class DeviceHparams(hp.Hparams):
    """Abstract base for the device hparams classes.

    Concrete subclasses (:class:`.CPUDeviceHparams` and
    :class:`.GPUDeviceHparams`) override :meth:`initialize_object` to build
    the matching :class:`.Device`. This base declares no dataclass fields.
    """

    @abstractmethod
    def initialize_object(self) -> Device:
        """Build the :class:`.Device` described by these hparams.

        Subclasses must override this; the base implementation has no body
        and returns ``None``.

        Returns:
            Device: the constructed device.
        """
@dataclass
class GPUDeviceHparams(DeviceHparams):
    """Hparams that build a :class:`.DeviceGPU`.

    Declares no fields of its own. Each call to :meth:`initialize_object`
    constructs a new :class:`.DeviceGPU` with no arguments.
    """

    def initialize_object(self) -> DeviceGPU:
        """Create and return a new :class:`.DeviceGPU` instance."""
        gpu_device = DeviceGPU()
        return gpu_device
@dataclass
class CPUDeviceHparams(DeviceHparams):
    """Hparams that build a :class:`.DeviceCPU`.

    Declares no fields of its own. Each call to :meth:`initialize_object`
    constructs a new :class:`.DeviceCPU` with no arguments.
    """

    def initialize_object(self) -> DeviceCPU:
        """Create and return a new :class:`.DeviceCPU` instance."""
        cpu_device = DeviceCPU()
        return cpu_device