Source code for composer.models.unet.unet

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""A U-Net model extending :class:`.ComposerModel`."""

import logging
from typing import Any, Sequence, Tuple, Union

import torch
import torch.nn as nn
from torchmetrics import Metric, MetricCollection

from composer.metrics.metrics import Dice
from composer.models.base import ComposerModel
from composer.models.unet.model import UNet as UNetModel
from composer.utils.import_helpers import MissingConditionalImportError

log = logging.getLogger(__name__)

__all__ = ["UNet"]


[docs]class UNet(ComposerModel): """A U-Net model extending :class:`.ComposerModel`. See U-Net: Convolutional Networks for Biomedical Image Segmentation (`Ronneberger et al, 2015`_) on the U-Net architecture. Args: num_classes (int, optional): The number of classes. Needed for classification tasks. Default: ``3``. .. _Ronneberger et al, 2015: https://arxiv.org/abs/1505.04597 """ def __init__(self, num_classes: int = 3) -> None: super().__init__() try: from monai.losses import DiceLoss except ImportError as e: raise MissingConditionalImportError(extra_deps_group="unet", conda_package="monai", conda_channel="conda-forge") from e self.module = self.build_nnunet() self.dice = Dice(num_classes=num_classes) self.dloss = DiceLoss(include_background=False, softmax=True, to_onehot_y=True, batch=True) self.closs = nn.CrossEntropyLoss() def loss(self, outputs: Any, batch: Any, *args, **kwargs) -> Union[torch.Tensor, Sequence[torch.Tensor]]: _, y = batch y = y.squeeze(1) # type: ignore loss = self.dloss(outputs, y) loss += self.closs(outputs, y[:, 0].long()) return loss @staticmethod def metric_mean(name, outputs): return torch.stack([out[name] for out in outputs]).mean(dim=0) def metrics(self, train: bool = False) -> Union[Metric, MetricCollection]: return self.dice def forward(self, batch: Any) -> torch.Tensor: x, _ = batch x = x.squeeze(1) # type: ignore logits = self.module(x) return logits
[docs] def inference2d(self, image): """Runs inference on a 3D image, by passing each depth slice through the model.""" batch_modulo = image.shape[2] % 64 if batch_modulo != 0: batch_pad = 64 - batch_modulo image = nn.ConstantPad3d((0, 0, 0, 0, batch_pad, 0), 0)(image) image = torch.transpose(image.squeeze(0), 0, 1) preds_shape = (image.shape[0], 4, *image.shape[2:]) preds = torch.zeros(preds_shape, dtype=image.dtype, device=image.device) for start in range(0, image.shape[0] - 64 + 1, 64): end = start + 64 with torch.no_grad(): pred = self.module(image[start:end]) preds[start:end] = pred.data if batch_modulo != 0: preds = preds[batch_pad:] # type: ignore return torch.transpose(preds, 0, 1).unsqueeze(0)
def validate(self, batch: Any) -> Tuple[Any, Any]: assert self.training is False, "For validation, model must be in eval mode" image, target = batch pred = self.inference2d(image) return pred, target[:, 0].long() # type: ignore def build_nnunet(self) -> torch.nn.Module: kernels = [[3, 3]] * 6 strides = [[1, 1]] + [[2, 2]] * 5 model = UNetModel(in_channels=4, n_class=4, kernels=kernels, strides=strides, dimension=2, residual=True, normalization_layer="batch", negative_slope=0.01) return model