Source code for composer.models.resnet_cifar.model
# Copyright 2021 MosaicML. All Rights Reserved.
"""ResNet models for CIFAR extending :class:`.ComposerClassifier`."""
from typing import List, Optional
from composer.models.initializers import Initializer
from composer.models.resnet_cifar.resnets import ResNet9, ResNetCIFAR
from composer.models.tasks import ComposerClassifier
__all__ = ["ComposerResNetCIFAR"]
[docs]class ComposerResNetCIFAR(ComposerClassifier):
"""ResNet models for CIFAR10 extending :class:`.ComposerClassifier`.
From `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_ (He et al, 2015).
ResNet9 is based on the model from myrtle.ai `blog`_.
Args:
model_name (str): ``"resnet_9"``, ``"resnet_20"``, or ``"resnet_56"``.
num_classes (int, optional): The number of classes. Needed for classification tasks. Default: ``10``.
initializers (List[Initializer], optional): Initializers for the model. ``None`` for no initialization. Default: ``None``.
Example:
.. testcode::
from composer.models import ComposerResNetCIFAR
model = ComposerResNetCIFAR(model_name="resnet_56") # creates a resnet56 for cifar image classification
.. _blog: https://myrtle.ai/learn/how-to-train-your-resnet-4-architecture/
"""
def __init__(
self,
model_name: str,
num_classes: int = 10,
initializers: Optional[List[Initializer]] = None,
) -> None:
if initializers is None:
initializers = []
if model_name == "resnet_9":
model = ResNet9(num_classes) # current initializers don't work with this architecture.
else:
model = ResNetCIFAR.get_model_from_name(
model_name,
initializers,
num_classes,
)
super().__init__(module=model)