"""Single Shot Object Detection model with pretrained ResNet34 backbone extending :class:`.ComposerModel`."""

import os
import tempfile
from typing import Any, Sequence, Tuple, Union

import numpy as np
import requests
from torch import Tensor
from torchmetrics import Metric, MetricCollection

from composer.core.types import BatchPair
from composer.models.base import ComposerModel
from composer.models.ssd.base_model import Loss
from composer.models.ssd.ssd300 import SSD300
from composer.models.ssd.utils import Encoder, SSDTransformer, dboxes300_coco
from composer.utils.import_helpers import MissingConditionalImportError

__all__ = ["SSD"]

class SSD(ComposerModel):
    """Single Shot Object detection Model with pretrained ResNet34 backbone extending :class:`.ComposerModel`.

    The pretrained backbone weights are downloaded at construction time, and the COCO
    validation dataset and mAP metric are built from ``data``.

    Args:
        input_size (int, optional): input image size. Default: ``300``.
        overlap_threshold (float, optional): Minimum IOU threshold for NMS. Default: ``0.5``.
        nms_max_detections (int, optional): Max number of boxes after NMS. Default: ``200``.
        num_classes (int, optional): The number of classes to detect. Default: ``80``.
        data (str, optional): path to coco dataset. Default: ``"/localdisk/coco"``.
    """

    def __init__(self,
                 input_size: int = 300,
                 overlap_threshold: float = 0.5,
                 nms_max_detections: int = 200,
                 num_classes: int = 80,
                 data: str = "/localdisk/coco"):
        super().__init__()

        self.input_size = input_size
        self.overlap_threshold = overlap_threshold
        self.nms_max_detections = nms_max_detections
        self.num_classes = num_classes

        # NOTE(review): the backbone checkpoint URL is an empty string in this copy of the
        # file; ``requests.get("")`` raises ``requests.exceptions.MissingSchema``. Restore the
        # intended URL before constructing this model.
        url = ""
        with tempfile.TemporaryDirectory() as tempdir:
            pretrained_backbone = os.path.join(tempdir, "weights.pth")
            # Stream to disk in chunks instead of buffering the whole checkpoint in memory.
            # The timeout bounds connect/read stalls so a dead server cannot hang construction.
            with requests.get(url, stream=True, timeout=60) as r:
                r.raise_for_status()
                with open(pretrained_backbone, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            # Must be built inside this ``with``: the temporary directory (and the weights
            # file) is deleted as soon as the block exits.
            self.module = SSD300(self.num_classes, model_path=pretrained_backbone)

        dboxes = dboxes300_coco()
        self.loss_func = Loss(dboxes)
        self.encoder = Encoder(dboxes)

        self.data = data
        self.MAP = coco_map(self.data)

        val_annotate = os.path.join(self.data, "annotations/instances_val2017.json")
        val_coco_root = os.path.join(self.data, "val2017")
        val_trans = SSDTransformer(dboxes, (self.input_size, self.input_size), val=True)
        from composer.datasets.coco import COCODetection
        self.val_coco = COCODetection(val_coco_root, val_annotate, val_trans)

    def loss(self, outputs: Any, batch: BatchPair) -> Union[Tensor, Sequence[Tensor]]:
        """Compute the SSD training loss.

        Args:
            outputs: the ``(ploc, plabel)`` pair returned by :meth:`forward`.
            batch: a 5-tuple whose 4th and 5th entries are the ground-truth boxes and labels.

        Returns:
            Whatever ``Loss`` returns for the predictions and ground truth.

        Raises:
            TypeError: if the ground-truth boxes are not a single tensor.
        """
        (_, _, _, bbox, label) = batch  #type: ignore
        if not isinstance(bbox, Tensor):
            raise TypeError("bbox must be a singular tensor")
        # Swap axes 1 and 2 into the layout ``Loss`` expects
        # (presumably (batch, boxes, 4) -> (batch, 4, boxes) — TODO confirm).
        gloc = bbox.transpose(1, 2).contiguous()
        ploc, plabel = outputs
        return self.loss_func(ploc, plabel, gloc, label)

    def metrics(self, train: bool = False) -> Union[Metric, MetricCollection]:
        """Return the COCO mAP metric; the same object is returned regardless of ``train``."""
        return self.MAP

    def forward(self, batch: BatchPair) -> Tuple[Tensor, Tensor]:
        """Run the network on the batch images and return ``(ploc, plabel)``."""
        (img, _, _, _, _) = batch  #type: ignore
        ploc, plabel = self.module(img)
        return ploc, plabel

    def validate(self, batch: BatchPair) -> Tuple[Any, Any]:
        """Run inference on a batch and convert detections into COCO result rows.

        Each row is ``[image_id, x, y, width, height, score, category_id]``: the decoded box
        ``loc`` (assumed normalized ``(x1, y1, x2, y2)`` — TODO confirm) is scaled by the
        per-image ``(height, width)`` taken from the batch, and the model's label index is
        mapped back to the dataset category id via ``val_coco.label_map``.

        Returns:
            The same list of rows twice (as outputs and as targets), which is what
            :class:`coco_map` consumes.
        """
        inv_map = {v: k for k, v in self.val_coco.label_map.items()}
        (img, img_id, img_size, _, _) = batch  #type: ignore
        ploc, plabel = self.module(img)

        results = []
        try:
            results = self.encoder.decode_batch(ploc,
                                                plabel,
                                                self.overlap_threshold,
                                                self.nms_max_detections,
                                                nms_valid_thresh=0.05)
        except Exception:
            # Best-effort: a decode failure is reported and treated as "no detections".
            # (Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit propagate.)
            print("No object detected")

        (htot, wtot) = [d.cpu().numpy() for d in img_size]  #type: ignore
        img_id = img_id.cpu().numpy()  #type: ignore

        ret = []
        # Iterate over batch elements; an empty ``results`` simply yields no rows.
        for img_id_, wtot_, htot_, result in zip(img_id, wtot, htot, results):
            loc, label, prob = [r.cpu().numpy() for r in result]  #type: ignore
            # Iterate over the detections of one image.
            for loc_, label_, prob_ in zip(loc, label, prob):
                ret.append([
                    img_id_,
                    loc_[0] * wtot_,
                    loc_[1] * htot_,
                    (loc_[2] - loc_[0]) * wtot_,
                    (loc_[3] - loc_[1]) * htot_,
                    prob_,
                    inv_map[label_],
                ])
        return ret, ret
class coco_map(Metric):
    """COCO bounding-box mAP metric computed with ``pycocotools``.

    Each ``update`` call receives a batch of detection rows of the form
    ``[image_id, x, y, width, height, score, category_id]``. ``compute`` evaluates all
    accumulated rows against the COCO val2017 ground truth.

    Args:
        data (str): root of the COCO dataset; ground truth is read from
            ``<data>/annotations/instances_val2017.json``.

    Raises:
        MissingConditionalImportError: if ``pycocotools`` is not installed.
    """

    def __init__(self, data):
        super().__init__()
        try:
            from pycocotools.coco import COCO
        except ImportError as e:
            raise MissingConditionalImportError(extra_deps_group="coco",
                                                conda_channel="conda-forge",
                                                conda_package="pycocotools") from e
        # List of per-batch detection-row lists; flattened in ``compute``.
        self.add_state("predictions", default=[])
        val_annotate = os.path.join(data, "annotations/instances_val2017.json")
        self.cocogt = COCO(annotation_file=val_annotate)

    def update(self, pred, target):
        """Accumulate one batch of detection rows (``target`` is unused).

        Batches are stored as-is; they may hold different numbers of rows, so no array
        is built here.
        """
        self.predictions.append(pred)  #type: ignore

    def compute(self):
        """Return COCO AP@[.5:.95] (``COCOeval.stats[0]``) over all accumulated detections.

        Returns ``0.0`` when no detections were accumulated, since ``COCO.loadRes`` cannot
        load an empty result set.
        """
        try:
            from pycocotools.cocoeval import COCOeval
        except ImportError as e:
            raise MissingConditionalImportError(extra_deps_group="coco",
                                                conda_channel="conda-forge",
                                                conda_package="pycocotools") from e
        # Flatten per-batch lists into one N x 7 array, the shape ``loadRes`` requires for
        # ndarray input. Empty batches contribute nothing.
        rows = [row for batch in self.predictions for row in batch]  #type: ignore
        if not rows:
            return 0.0
        detections = np.asarray(rows, dtype=np.float64).reshape(-1, 7)
        cocoDt = self.cocogt.loadRes(detections)
        E = COCOeval(self.cocogt, cocoDt, iouType='bbox')
        E.evaluate()
        E.accumulate()
        E.summarize()
        return E.stats[0]