# Source code for composer.models.resnet_cifar.resnet_cifar_hparams

# Copyright 2021 MosaicML. All Rights Reserved.

"""`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ interface for
:class:`.ComposerResNetCIFAR`."""
from dataclasses import dataclass
from typing import Optional

import yahp as hp

from composer.models.model_hparams import ModelHparams
from composer.models.resnet_cifar import ComposerResNetCIFAR

# Names exported by ``from ... import *``; only the hparams class is public here.
__all__ = ["ResNetCIFARHparams"]


@dataclass
class ResNetCIFARHparams(ModelHparams):
    """:class:`~.hp.Hparams` interface for :class:`.ComposerResNetCIFAR`.

    Args:
        model_name (str): ``"resnet_9"``, ``"resnet_20"``, or ``"resnet_56"``. Must be provided;
            :meth:`validate` raises if it is left unset.
        num_classes (int, optional): The number of classes. Needed for classification tasks. Default: ``10``.
        initializers (List[Initializer], optional): Initializers for the model. ``None`` for no initialization.
            Default: ``None``.
    """

    # ``None`` means "not provided". ``validate`` rejects it so a missing model name is
    # reported at configuration time instead of failing inside the model constructor.
    model_name: Optional[str] = hp.optional('"resnet_9", "resnet_20" or "resnet_56"', default=None)
    num_classes: int = hp.optional("The number of classes. Needed for classification tasks", default=10)

    def validate(self):
        """Check that the hparams describe a buildable model.

        Runs the base-class validation first, then ensures ``model_name`` was supplied.

        Raises:
            ValueError: If ``model_name`` is ``None``.
        """
        super().validate()
        if self.model_name is None:
            raise ValueError('model name must be one of "resnet_9", "resnet_20" or "resnet_56".')

    def initialize_object(self):
        """Construct the model described by these hparams.

        ``initializers`` is not declared on this class; it is inherited from :class:`ModelHparams`.

        Returns:
            ComposerResNetCIFAR: A model built from ``model_name``, ``num_classes`` and ``initializers``.
        """
        return ComposerResNetCIFAR(model_name=self.model_name,
                                   num_classes=self.num_classes,
                                   initializers=self.initializers)