# Source code for composer.algorithms.algorithm_hparams
# Copyright 2021 MosaicML. All Rights Reserved.
from __future__ import annotations

import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional

import yahp as hp

import composer
from composer.core.algorithm import Algorithm
@dataclass
class AlgorithmHparams(hp.Hparams, ABC):
    """Hyperparameters for algorithms.

    Subclasses implement :meth:`initialize_object` to build the corresponding
    :class:`Algorithm`. The :meth:`load` and :meth:`load_multiple` classmethods
    build hparams from the YAML files bundled under
    ``<composer package dir>/yamls/algorithms``.
    """

    @abstractmethod
    def initialize_object(self) -> Algorithm:
        """Invoked by the :meth:`TrainerHparams.initialize_object` to create an instance of the :class:`Algorithm`.

        Returns:
            Algorithm: An instance of the :class:`Algorithm`.
        """

    @classmethod
    def load(cls, alg_params: Optional[str] = None) -> AlgorithmHparams:
        """Create hparams for this algorithm class from its bundled YAML file.

        The name under which ``cls`` is registered in the algorithm registry
        selects the file:

        * ``alg_params is None``: ``yamls/algorithms/<name>.yaml``
        * otherwise: ``yamls/algorithms/<name>/<alg_params>.yaml``

        Args:
            alg_params (str, optional): Name of a preset YAML inside the
                algorithm's folder, without the ``.yaml`` suffix. It may contain
                ``/`` to reach a nested file.

        Returns:
            AlgorithmHparams: The object built by ``cls.create`` from the YAML
            file if that file exists, otherwise ``cls()``.

        Raises:
            KeyError: If ``cls`` is not registered in the algorithm registry.
        """
        # Function-local import; presumably avoids a circular import with the
        # registry module (which refers to hparams classes) -- TODO confirm.
        from composer.algorithms.algorithm_registry import get_algorithm_registry

        registry = get_algorithm_registry()
        # The registry maps name -> hparams class; invert it to find this class's
        # name. If a class is registered under several names, the last one wins.
        inverted_registry = {v: k for (k, v) in registry.items()}
        try:
            alg_name = inverted_registry[cls]
        except KeyError as e:
            raise KeyError(f"{cls.__name__} is not registered in the algorithm registry") from e

        alg_folder = os.path.join(os.path.dirname(composer.__file__), "yamls", "algorithms")
        if alg_params is None:
            hparams_file = os.path.join(alg_folder, f"{alg_name}.yaml")
        else:
            hparams_file = os.path.join(alg_folder, alg_name, f"{alg_params}.yaml")

        if os.path.exists(hparams_file):
            alg_hparams = cls.create(hparams_file, cli_args=False)
            # Narrows the type for checkers; create() is expected to return cls.
            assert isinstance(alg_hparams, AlgorithmHparams), "hparams.create should return an instance of its type"
            return alg_hparams
        # NOTE(review): a missing file -- including a misspelled ``alg_params`` --
        # silently falls back to default hparams; confirm this is intended.
        return cls()

    @classmethod
    def load_multiple(cls, *algorithms: str) -> List[AlgorithmHparams]:
        """Create hparams for several algorithms given by name.

        Args:
            *algorithms (str): Each entry is either ``"<name>"`` or
                ``"<name>/<params>"``. ``<name>`` is a key of the algorithm
                registry. ``<params>`` is everything after the first ``/`` and is
                forwarded to :meth:`load` as ``alg_params``.

        Returns:
            List[AlgorithmHparams]: One hparams object per entry, in the order
            given.

        Raises:
            ValueError: If a ``<name>`` is not present in the algorithm registry.
        """
        from composer.algorithms.algorithm_registry import get_algorithm_registry

        registry = get_algorithm_registry()
        alg_hparams = []
        for alg in algorithms:
            # Split on the first "/" only; the remainder may itself contain "/"
            # and is passed through unchanged (even when empty, e.g. "name/").
            alg_name, sep, rest = alg.partition("/")
            alg_params = rest if sep else None
            try:
                hparams_cls = registry[alg_name]
            except KeyError as e:
                raise ValueError(f"Algorithm {alg_name} not found") from e
            alg_hparams.append(hparams_cls.load(alg_params))
        return alg_hparams