Source code for composer.models.resnet_cifar.resnet_cifar_hparams
# Copyright 2021 MosaicML. All Rights Reserved.
"""`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ interface for
:class:`.ComposerResNetCIFAR`."""
from dataclasses import dataclass
import yahp as hp
from composer.models.model_hparams import ModelHparams
from composer.models.resnet_cifar import ComposerResNetCIFAR
__all__ = ["ResNetCIFARHparams"]
@dataclass
class ResNetCIFARHparams(ModelHparams):
    """:class:`~.hp.Hparams` interface for :class:`.ComposerResNetCIFAR`.

    Args:
        model_name (str): ``"resnet_9"``, ``"resnet_20"``, or ``"resnet_56"``. Must be provided;
            the ``None`` default is rejected by :meth:`validate`.
        num_classes (int, optional): The number of classes. Needed for classification tasks. Default: ``10``.
        initializers (List[Initializer], optional): Initializers for the model. This field is not
            declared on this class; it is inherited from :class:`.ModelHparams`.
    """

    # ``None`` means "not provided"; ``validate`` raises if it is still unset.
    # The help text lists the same names as the class docstring (``resnet_*``).
    model_name: str = hp.optional('"resnet_9", "resnet_20" or "resnet_56"', default=None)
    num_classes: int = hp.optional("The number of classes. Needed for classification tasks", default=10)

    def validate(self):
        """Validate these hparams.

        Runs the base-class validation first, then checks that a model name was supplied.

        Raises:
            ValueError: If ``model_name`` was not provided.
        """
        # Overriding ``validate`` must not skip the validation implemented by the base classes.
        super().validate()
        if self.model_name is None:
            raise ValueError('model name must be one of "resnet_9", "resnet_20" or "resnet_56".')

    def initialize_object(self):
        """Construct the :class:`.ComposerResNetCIFAR` described by these hparams.

        Returns:
            ComposerResNetCIFAR: The model built from ``model_name``, ``num_classes`` and the
            inherited ``initializers`` field.
        """
        return ComposerResNetCIFAR(model_name=self.model_name,
                                   num_classes=self.num_classes,
                                   initializers=self.initializers)