Source code for composer.models.vit_small_patch16.hparams

# Copyright 2021 MosaicML. All Rights Reserved.

"""`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ interface for :class:`.ViTSmallPatch16`."""

from dataclasses import dataclass

import yahp as hp

from composer.models.model_hparams import ModelHparams
from composer.utils.import_helpers import MissingConditionalImportError

__all__ = ["ViTSmallPatch16Hparams"]


[docs]@dataclass class ViTSmallPatch16Hparams(ModelHparams): """`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ interface for :class:`.ViTSmallPatch16`. Args: num_classes (int, optional): number of classes for the model. Default: ``1000``. image_size (int, optional): input image size. If you have rectangular images, make sure your image size is the maximum of the width and height. Default: ``224``. channels (int, optional): number of image channels. Default: ``3``. dropout (float, optional): 0.0 - 1.0 dropout rate. Default: ``0``. embedding_dropout (float, optional): 0.0 - 1.0 embedding dropout rate. Default: ``0``. """ num_classes: int = hp.optional("number of classes. Needed for classification tasks", default=1000) image_size: int = hp.optional( "input image size. If you have rectangular images, make sure your image size is the maximum of the width and height", default=224) channels: int = hp.optional("number of image channels", default=3) dropout: float = hp.optional("dropout rate", default=0.0) embedding_dropout: float = hp.optional("embedding dropout rate", default=0.0) def validate(self): try: import vit_pytorch # type: ignore except ImportError as e: raise MissingConditionalImportError(extra_deps_group="vit", conda_package="vit_pytorch>=0.27") from e def initialize_object(self): from composer.models import ViTSmallPatch16 return ViTSmallPatch16(num_classes=self.num_classes, image_size=self.image_size, channels=self.channels, dropout=self.dropout, embedding_dropout=self.embedding_dropout)