# Source code for composer.models.unet.unet_hparams
# Copyright 2021 MosaicML. All Rights Reserved.
"""`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ interface for
:class:`~composer.models.unet.unet.UNet`."""
from dataclasses import dataclass
from composer.models.model_hparams import ModelHparams
# Public API of this module: the names exported by a wildcard import.
__all__ = ["UnetHparams"]
@dataclass
class UnetHparams(ModelHparams):
    """`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ interface for
    :class:`~composer.models.unet.unet.UNet`.

    Args:
        num_classes (int, optional): The number of classes passed to
            :class:`~composer.models.unet.unet.UNet`. It must be set:
            :meth:`initialize_object` raises :exc:`ValueError` if it is ``None``.
            Default: ``3``.
    """

    # NOTE(review): ``num_classes`` is not declared on this class; it is presumably
    # inherited from ``ModelHparams``. Confirm that the inherited default matches the
    # ``3`` documented above.

    def initialize_object(self):
        """Construct a :class:`~composer.models.unet.unet.UNet` from these hparams.

        Returns:
            UNet: A model built with ``num_classes=self.num_classes``.

        Raises:
            ValueError: If ``num_classes`` is ``None``.
        """
        # Imported inside the method, as in the original, so the model module is only
        # loaded when a model is actually built.
        from composer.models.unet.unet import UNet

        # Explicit check rather than ``assert``: asserts are stripped under ``python -O``,
        # which would let a missing value through silently.
        if self.num_classes is None:
            raise ValueError("num_classes must be specified.")
        return UNet(num_classes=self.num_classes)