composer_timm
- composer.models.composer_timm(model_name, pretrained=False, num_classes=1000, drop_rate=0.0, drop_path_rate=None, drop_block_rate=None, global_pool=None, bn_momentum=None, bn_eps=None)
A wrapper around timm.create_model() used to create a ComposerClassifier.
- Parameters
  - model_name (str) – timm model name, e.g. "resnet50". A list of available models can be found at PyTorch Image Models.
  - pretrained (bool, optional) – Whether to load ImageNet-pretrained weights. Default: False.
  - num_classes (int, optional) – The number of classes. Needed for classification tasks. Default: 1000.
  - drop_rate (float, optional) – Dropout rate. Default: 0.0.
  - drop_path_rate (float, optional) – Drop path rate (model default if None). Default: None.
  - drop_block_rate (float, optional) – Drop block rate (model default if None). Default: None.
  - global_pool (str, optional) – Global pool type, one of ("fast", "avg", "max", "avgmax", "avgmaxc"). Model default if None. Default: None.
  - bn_momentum (float, optional) – BatchNorm momentum override (model default if None). Default: None.
  - bn_eps (float, optional) – BatchNorm epsilon override (model default if None). Default: None.
- Returns
  ComposerModel – An instance of ComposerClassifier wrapping the specified timm model.
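The "model default if None" wording in the parameter list can be illustrated with a small helper. This is a minimal sketch, assuming that optional overrides left as None are simply not forwarded to the model constructor, so the architecture's own default applies (to our understanding, timm.create_model() discards None-valued keyword arguments in this way). The name build_create_model_kwargs is hypothetical and is not part of Composer or timm.

```python
from typing import Any, Dict, Optional


def build_create_model_kwargs(
    model_name: str,
    pretrained: bool = False,
    num_classes: int = 1000,
    drop_rate: float = 0.0,
    drop_path_rate: Optional[float] = None,
    drop_block_rate: Optional[float] = None,
    global_pool: Optional[str] = None,
    bn_momentum: Optional[float] = None,
    bn_eps: Optional[float] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: the keyword arguments a model factory would receive.

    Assumption: overrides left as None are dropped so the model's own default
    applies; explicitly set values are forwarded unchanged.
    """
    kwargs: Dict[str, Any] = {
        "model_name": model_name,
        "pretrained": pretrained,
        "num_classes": num_classes,
        "drop_rate": drop_rate,
    }
    overrides = {
        "drop_path_rate": drop_path_rate,
        "drop_block_rate": drop_block_rate,
        "global_pool": global_pool,
        "bn_momentum": bn_momentum,
        "bn_eps": bn_eps,
    }
    # Keep only the overrides the caller actually set.
    kwargs.update({k: v for k, v in overrides.items() if v is not None})
    return kwargs


# Only drop_path_rate is overridden; the other optional settings are never
# forwarded, so the model would fall back to its own defaults for them.
print(build_create_model_kwargs("resnet50", num_classes=10, drop_path_rate=0.1))
# {'model_name': 'resnet50', 'pretrained': False, 'num_classes': 10, 'drop_rate': 0.0, 'drop_path_rate': 0.1}
```

Filtering on `is not None` (rather than on truthiness) matters here: an explicit override such as drop_rate=0.0 or bn_momentum=0.0 is still forwarded.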
Resnet18 Example:
from composer.models import composer_timm

model = composer_timm(model_name='resnet18')  # creates a timm resnet18
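The global_pool choices reduce a convolutional feature map to one vector per image before the classifier head. The pure-Python sketch below shows what each option computes conceptually, assuming that "fast" yields the same average as "avg" (differing only in implementation), "avgmax" averages the average- and max-pooled vectors, and "avgmaxc" concatenates them. The function pool_features is hypothetical and is not timm's implementation.

```python
from typing import List

# One image's feature map as nested lists: channels x height x width.
FeatureMap = List[List[List[float]]]


def pool_features(features: FeatureMap, pool_type: str = "avg") -> List[float]:
    """Conceptual sketch of the global_pool options (not timm's code).

    Assumptions: "fast" == "avg" numerically, "avgmax" = 0.5 * (avg + max),
    "avgmaxc" = concatenation of avg and max (twice the feature length).
    """
    # Per-channel spatial mean and maximum.
    avg = [sum(sum(row) for row in ch) / sum(len(row) for row in ch) for ch in features]
    mx = [max(max(row) for row in ch) for ch in features]
    if pool_type in ("avg", "fast"):
        return avg
    if pool_type == "max":
        return mx
    if pool_type == "avgmax":
        return [0.5 * (a + m) for a, m in zip(avg, mx)]
    if pool_type == "avgmaxc":
        return avg + mx
    raise ValueError(f"unknown pool_type: {pool_type!r}")


# Two channels, each a 2x2 spatial map.
fmap = [[[1.0, 2.0], [3.0, 6.0]], [[0.0, 0.0], [4.0, 0.0]]]
print(pool_features(fmap, "avg"))      # [3.0, 1.0]
print(pool_features(fmap, "avgmaxc"))  # [3.0, 1.0, 6.0, 4.0]
```

Under these assumptions, "avgmaxc" doubles the length of the pooled vector, so the classifier layer that follows it would need twice as many input features as with the other options.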