# Source code for composer.models.resnet.model

# Copyright 2021 MosaicML. All Rights Reserved.

"""A :class:`.ComposerClassifier` wrapper around the torchvision implementations of the ResNet model family."""

from typing import List, Optional

from torchvision.models import resnet

from composer.models.initializers import Initializer
from composer.models.tasks import ComposerClassifier

# Public names re-exported by ``from <this module> import *``.
__all__: List[str] = ["ComposerResNet"]


class ComposerResNet(ComposerClassifier):
    """A :class:`.ComposerClassifier` wrapper around the torchvision implementations of the ResNet model family.

    From `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_ (He et al., 2015).

    Args:
        model_name (str): Name of the ResNet model instance. One of [``"resnet18"``, ``"resnet34"``,
            ``"resnet50"``, ``"resnet101"``, ``"resnet152"``].
        num_classes (int, optional): The number of classes. Needed for classification tasks. Default: ``1000``.
        pretrained (bool, optional): If True, use ImageNet pretrained weights. Default: ``False``.
        groups (int, optional): Number of filter groups for the 3x3 convolution layer in bottleneck blocks.
            Default: ``1``.
        width_per_group (int, optional): Initial width for each convolution group. Width doubles after each stage.
            Default: ``64``.
        initializers (List[Initializer], optional): Initializers for the model. ``None`` for no initialization.
            Default: ``None``.

    Raises:
        ValueError: If ``model_name`` is not one of :attr:`valid_model_names`.

    Example:

    .. testcode::

        from composer.models import ComposerResNet

        model = ComposerResNet(model_name='resnet18')  # creates a torchvision resnet18 for image classification
    """

    # Names accepted for ``model_name``; each is looked up as a constructor on ``torchvision.models.resnet``.
    valid_model_names = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]

    def __init__(
        self,
        model_name: str,
        num_classes: int = 1000,
        pretrained: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        initializers: Optional[List[Initializer]] = None,
    ) -> None:
        # Validate up front so an unknown name fails with a clear message instead of an AttributeError
        # from the getattr lookup below.
        if model_name not in self.valid_model_names:
            raise ValueError(f"model_name must be one of {self.valid_model_names} instead of {model_name}.")

        if initializers is None:
            initializers = []

        # Resolve e.g. "resnet50" to torchvision's resnet50 constructor.
        # NOTE(review): newer torchvision releases deprecate ``pretrained`` in favor of ``weights`` --
        # TODO confirm against the pinned torchvision version.
        model_func = getattr(resnet, model_name)
        model = model_func(pretrained=pretrained, num_classes=num_classes, groups=groups, width_per_group=width_per_group)

        # Normalize each entry through ``Initializer(...)`` (presumably so raw values such as strings are also
        # accepted -- TODO confirm), then apply its init function recursively to every submodule.
        for init_spec in initializers:
            model.apply(Initializer(init_spec).get_initializer())

        super().__init__(module=model)