# Source code for composer.models.timm.timm_hparams

# Copyright 2021 MosaicML. All Rights Reserved.

"""`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ interface for :class:`.Timm`."""

import textwrap
from dataclasses import dataclass
from typing import Optional

import yahp as hp

from composer.models.model_hparams import ModelHparams
from composer.models.timm.model import Timm
from composer.utils.import_helpers import MissingConditionalImportError

__all__ = ["TimmHparams"]


@dataclass
class TimmHparams(ModelHparams):
    """`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ interface for :class:`.Timm`.

    Args:
        model_name (str): timm model name e.g: ``"resnet50"``. List of models can be found at
            `PyTorch Image Models <https://github.com/rwightman/pytorch-image-models>`_.
        pretrained (bool, optional): Imagenet pretrained. Default: ``False``.
        num_classes (int, optional): The number of classes. Needed for classification tasks. Default: ``1000``.
        drop_rate (float, optional): Dropout rate. Default: ``0.0``.
        drop_path_rate (float, optional): Drop path rate (model default if ``None``). Default: ``None``.
        drop_block_rate (float, optional): Drop block rate (model default if ``None``). Default: ``None``.
        global_pool (str, optional): Global pool type, one of (``"fast"``, ``"avg"``, ``"max"``, ``"avgmax"``,
            ``"avgmaxc"``). Model default if ``None``. Default: ``None``.
        bn_momentum (float, optional): BatchNorm momentum override (model default if ``None``). Default: ``None``.
        bn_eps (float, optional): BatchNorm epsilon override (model default if ``None``). Default: ``None``.
    """

    # The ``None`` default is a "not provided" sentinel; :meth:`validate` rejects it.
    model_name: str = hp.optional(
        textwrap.dedent("""\
            timm model name e.g: 'resnet50', list of models can be found at
            https://github.com/rwightman/pytorch-image-models"""),
        default=None,
    )
    pretrained: bool = hp.optional("imagenet pretrained", default=False)
    num_classes: int = hp.optional("The number of classes. Needed for classification tasks", default=1000)
    drop_rate: float = hp.optional("dropout rate", default=0.0)
    drop_path_rate: Optional[float] = hp.optional("drop path rate (model default if None)", default=None)
    drop_block_rate: Optional[float] = hp.optional("drop block rate (model default if None)", default=None)
    global_pool: Optional[str] = hp.optional(
        "Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.",
        default=None,
    )
    # Help text aligned with the sibling fields: ``None`` means "keep the model default".
    bn_momentum: Optional[float] = hp.optional(
        "BatchNorm momentum override (model default if None)",
        default=None,
    )
    bn_eps: Optional[float] = hp.optional(
        "BatchNorm epsilon override (model default if None)",
        default=None,
    )

    def validate(self) -> None:
        """Check that a model name was supplied.

        Raises:
            MissingConditionalImportError: If ``model_name`` is ``None`` and ``timm`` cannot be
                imported to build the list of valid names.
            ValueError: If ``model_name`` is ``None``; the message lists the names reported by
                ``timm.models.list_models()``.
        """
        if self.model_name is not None:
            return

        # timm is imported lazily: it is only needed here to enumerate the valid
        # model names for the error message.
        try:
            import timm
        except ImportError as e:
            raise MissingConditionalImportError(extra_deps_group="timm", conda_package="timm >=0.5.4") from e
        raise ValueError(f"model must be one of {timm.models.list_models()}")

    def initialize_object(self) -> Timm:
        """Build a :class:`.Timm` model, forwarding every hyperparameter field unchanged.

        Returns:
            Timm: The model constructed from these hyperparameters.
        """
        return Timm(
            model_name=self.model_name,
            pretrained=self.pretrained,
            num_classes=self.num_classes,
            drop_rate=self.drop_rate,
            drop_path_rate=self.drop_path_rate,
            drop_block_rate=self.drop_block_rate,
            global_pool=self.global_pool,
            bn_momentum=self.bn_momentum,
            bn_eps=self.bn_eps,
        )