# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""A wrapper class that converts mmdet detection models to composer models."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, List, Optional

import numpy as np
import torch
from torchmetrics import Metric
from torchmetrics.collections import MetricCollection

from composer.models import ComposerModel

if TYPE_CHECKING:
    import mmdet

__all__ = ['MMDetModel']


class MMDetModel(ComposerModel):
    """A wrapper class that adapts mmdetection detectors to composer models.

    Args:
        model (mmdet.models.detectors.BaseDetector): An MMDetection detector.
        metrics (list[Metric], optional): list of torchmetrics to apply to the output of `eval_forward`.
            Default: ``None``.

    .. warning:: This wrapper is designed to work with mmdet datasets.

    Example:

    .. code-block:: python

        from mmdet.models import build_model
        from mmcv import ConfigDict
        from composer.models import MMDetModel

        num_classes = 80  # e.g., the number of COCO classes

        yolox_s_config = dict(
            type='YOLOX',
            input_size=(640, 640),
            random_size_range=(15, 25),
            random_size_interval=10,
            backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
            neck=dict(type='YOLOXPAFPN', in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1),
            bbox_head=dict(type='YOLOXHead', num_classes=num_classes, in_channels=128, feat_channels=128),
            train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
            test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))

        yolox = build_model(ConfigDict(yolox_s_config))
        yolox.init_weights()
        model = MMDetModel(yolox)
    """

    def __init__(
            self,
            model: mmdet.models.detectors.BaseDetector,  # type: ignore
            metrics: Optional[List[Metric]] = None) -> None:
        super().__init__()
        self.model = model

        self.train_metrics = None
        self.val_metrics = None

        if metrics:
            metric_collection = MetricCollection(metrics)
            self.train_metrics = metric_collection.clone(prefix='train_')
            self.val_metrics = metric_collection.clone(prefix='val_')

    def forward(self, batch):
        # This will return a dictionary of losses in train mode and model outputs in test mode.
        return self.model(**batch)

    def loss(self, outputs, batch, **kwargs):
        # In train mode, ``forward`` already returns mmdet's loss dictionary, so it is
        # passed through unchanged.
        return outputs
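
    # A minimal sketch (illustrative only) of the training batch dict that ``forward``
    # unpacks into the detector. The key names follow what ``eval_forward`` below
    # reads and pops; the shapes are assumptions based on mmdet conventions:
    #
    #     batch = {
    #         'img': images,        # torch.Tensor of shape (batch, C, H, W)
    #         'img_metas': metas,   # list of per-image meta dicts ('ori_shape', 'scale_factor', ...)
    #         'gt_bboxes': bboxes,  # list of (num_boxes, 4) tensors, one per image
    #         'gt_labels': labels,  # list of (num_boxes,) tensors, one per image
    #     }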

    def eval_forward(self, batch, outputs: Optional[Any] = None):
        """
        Args:
            batch (dict): an eval batch of the format:

                ``img`` (List[torch.Tensor]): list of image torch.Tensors of shape (batch, c, h, w).

                ``img_metas`` (List[Dict]): (1, batch_size) list of ``image_meta`` dicts.

        Returns:
            model predictions: A batch_size length list of dictionaries containing detection boxes
            in (x1, y1, x2, y2) format, class labels, and class probabilities.
        """
        device = batch['img'][0].device

        # Ground-truth keys are not accepted by the detector in test mode, so drop them.
        batch.pop('gt_labels')
        batch.pop('gt_bboxes')
        results = self.model(return_loss=False, rescale=True, **batch)  # models behave differently in eval mode

        # outputs are a list of bbox results (x1, y1, x2, y2, score)
        # pack mmdet bounding boxes and labels into the format torchmetrics MAP expects
        preds = []
        for bbox_result in results:
            boxes_scores = np.vstack(bbox_result)
            boxes, scores = torch.from_numpy(boxes_scores[..., :-1]).to(device), torch.from_numpy(
                boxes_scores[..., -1]).to(device)
            labels = [np.full(result.shape[0], i, dtype=np.int32) for i, result in enumerate(bbox_result)]
            pred = {
                'labels': torch.from_numpy(np.concatenate(labels)).to(device).long(),
                'boxes': boxes.float(),
                'scores': scores.float()
            }
            preds.append(pred)
        return preds
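

# A minimal usage sketch (illustrative only, not part of the wrapper) showing the
# prediction format ``MMDetModel.eval_forward`` produces and how it feeds
# torchmetrics' MeanAveragePrecision. With a real detector, ``preds`` would come
# from ``model.eval_forward(batch)``; the tensors below are fabricated, and the
# import path may differ across torchmetrics versions.
if __name__ == '__main__':
    from torchmetrics.detection.mean_ap import MeanAveragePrecision

    metric = MeanAveragePrecision()
    preds = [{
        'boxes': torch.tensor([[10.0, 20.0, 50.0, 80.0]]),  # (x1, y1, x2, y2)
        'scores': torch.tensor([0.9]),
        'labels': torch.tensor([3]),
    }]
    targets = [{
        'boxes': torch.tensor([[12.0, 18.0, 48.0, 82.0]]),
        'labels': torch.tensor([3]),
    }]
    metric.update(preds, targets)
    print(metric.compute())  # dict with 'map', 'map_50', etc.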