Source code for composer.models.resnet_cifar.model

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""ResNet models for CIFAR extending :class:`.ComposerClassifier`."""

from typing import List, Optional

from composer.models.initializers import Initializer
from composer.models.resnet_cifar.resnets import ResNet9, ResNetCIFAR
from composer.models.tasks import ComposerClassifier

__all__ = ['composer_resnet_cifar']


def composer_resnet_cifar(model_name: str,
                          num_classes: int = 10,
                          initializers: Optional[List[Initializer]] = None) -> ComposerClassifier:
    """Helper function to create a :class:`.ComposerClassifier` with a CIFAR ResNet model.

    From `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_ (He et al, 2015).
    ResNet9 is based on the model from myrtle.ai `blog`_.

    Args:
        model_name (str): ``"resnet_9"``, ``"resnet_20"``, or ``"resnet_56"``. Any name other than
            ``"resnet_9"`` is resolved by :meth:`ResNetCIFAR.get_model_from_name`.
        num_classes (int, optional): The number of classes. Needed for classification tasks. Default: ``10``.
        initializers (List[Initializer], optional): Initializers for the model. ``None`` for no
            initialization. Ignored when ``model_name`` is ``"resnet_9"``, because the current
            initializers do not work with that architecture. Default: ``None``.

    Returns:
        ComposerClassifier: instance of :class:`.ComposerClassifier` wrapping a CIFAR ResNet model.

    Example:

    .. testcode::

        from composer.models import composer_resnet_cifar

        model = composer_resnet_cifar(model_name="resnet_56")  # creates a resnet56 for cifar image classification

    .. _blog: https://myrtle.ai/learn/how-to-train-your-resnet-4-architecture/
    """
    # Normalize the optional argument so downstream code always receives a list.
    if initializers is None:
        initializers = []

    if model_name == 'resnet_9':
        # Current initializers don't work with this architecture, so they are not passed through.
        model = ResNet9(num_classes)
    else:
        model = ResNetCIFAR.get_model_from_name(model_name, initializers, num_classes)

    composer_model = ComposerClassifier(module=model, num_classes=num_classes)
    return composer_model