# Source code for composer.models.unet.unet

# Copyright 2021 MosaicML. All Rights Reserved.

import contextlib
import logging
from typing import Any, Optional, Tuple

import torch
import torch.nn as nn

from composer.core.types import BatchPair, Metrics, Tensor, Tensors
from composer.models.base import BaseMosaicModel
from composer.models.loss import Dice
from composer.models.unet.model import UNet as UNetModel
from composer.models.unet.unet_hparams import UnetHparams

log = logging.getLogger(__name__)


class UNet(BaseMosaicModel):
    """A U-Net model extending :class:`BaseMosaicModel`.

    See this `paper <https://arxiv.org/abs/1505.04597>`_ for details on the U-Net architecture.

    Args:
        hparams (UnetHparams): The hyperparameters for constructing the model.
    """

    n_classes: Optional[int] = None

    # Channel counts of the underlying ``UNetModel``. ``inference2d`` sizes its output
    # buffer from ``_OUT_CHANNELS`` so it always agrees with ``build_nnunet``.
    _IN_CHANNELS = 4
    _OUT_CHANNELS = 4

    def __init__(self, hparams: UnetHparams) -> None:
        super().__init__()
        # Imported here rather than at module level (presumably so ``monai`` is only
        # needed when a model is actually constructed).
        from monai.losses import DiceLoss

        self.hparams = hparams
        self.module = self.build_nnunet()
        # NOTE(review): 3 presumably = foreground classes (4 outputs minus background);
        # confirm against ``Dice``.
        self.dice = Dice(3)
        self.dloss = DiceLoss(include_background=False, softmax=True, to_onehot_y=True, batch=True)
        self.closs = nn.CrossEntropyLoss()

    def loss(self, outputs: Any, batch: BatchPair) -> Tensors:
        """Compute the training loss: MONAI Dice loss plus cross-entropy.

        Args:
            outputs: Logits produced by :meth:`forward`.
            batch: ``(inputs, targets)`` pair; only the targets are used.

        Returns:
            The sum of the Dice loss and the cross-entropy loss.
        """
        _, y = batch
        # Drop dim 1 of the targets (assumed to be a singleton from the loader; confirm).
        y = y.squeeze(1)  # type: ignore
        assert isinstance(y, Tensor)
        # DiceLoss is configured with to_onehot_y=True, so it takes the label map directly.
        loss = self.dloss(outputs, y)
        # CrossEntropyLoss takes integer class indices with no channel dimension.
        loss += self.closs(outputs, y[:, 0].long())
        return loss

    @staticmethod
    def metric_mean(name, outputs):
        """Stack ``out[name]`` for every ``out`` in ``outputs`` and average along dim 0."""
        return torch.stack([out[name] for out in outputs]).mean(dim=0)

    def metrics(self, train: bool = False) -> Metrics:
        """Return the Dice metric; the same object is used for training and evaluation."""
        return self.dice

    def forward(self, batch: BatchPair) -> Tensor:
        """Run the network on the inputs of ``batch`` and return its logits.

        In training mode gradients are tracked; in eval mode the forward pass runs
        under :func:`torch.no_grad`.
        """
        x, _ = batch
        x = x.squeeze(1)  # type: ignore
        context = contextlib.nullcontext if self.training else torch.no_grad
        with context():
            logits = self.module(x)
        return logits

    def inference2d(self, image, batch_size: int = 64):
        """Run inference on a 3D image by passing each depth slice through the 2D model.

        The depth axis (dim 2) is zero-padded at the front up to a multiple of
        ``batch_size``. The slices are run through the model ``batch_size`` at a time,
        and the padded slices are dropped from the result.

        Args:
            image: 5D volume, presumably ``(1, C, D, H, W)``. Dim 0 is squeezed away,
                so a batch size other than 1 is not supported; confirm with callers.
            batch_size: Number of depth slices per forward pass. Defaults to 64.

        Returns:
            Logits of shape ``(1, _OUT_CHANNELS, D, H, W)`` with ``image``'s dtype and device.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        # Front padding that makes the depth a multiple of batch_size (0 when already aligned).
        batch_pad = (-image.shape[2]) % batch_size
        if batch_pad:
            # The pad tuple runs (W_lo, W_hi, H_lo, H_hi, D_lo, D_hi): pad the front of depth only.
            image = nn.ConstantPad3d((0, 0, 0, 0, batch_pad, 0), 0)(image)

        # (1, C, D, H, W) -> (D, C, H, W): every depth slice becomes one 2D sample.
        image = torch.transpose(image.squeeze(0), 0, 1)
        depth = image.shape[0]
        preds = torch.zeros((depth, self._OUT_CHANNELS, *image.shape[2:]), dtype=image.dtype, device=image.device)

        with torch.no_grad():
            # depth is a multiple of batch_size, so every chunk is full-sized.
            for start in range(0, depth, batch_size):
                end = start + batch_size
                preds[start:end] = self.module(image[start:end])

        # Discard the padded slices; with batch_pad == 0 this is a no-op.
        preds = preds[batch_pad:]
        return torch.transpose(preds, 0, 1).unsqueeze(0)

    def validate(self, batch: BatchPair) -> Tuple[Any, Any]:
        """Predict a full volume and pair it with integer labels for metric computation.

        Raises:
            AssertionError: if the model is in training mode. The check is explicit
                so it still runs when Python is started with ``-O``.
        """
        if self.training:
            raise AssertionError("For validation, model must be in eval mode")
        img, lbl = batch
        pred = self.inference2d(img)
        return pred, lbl[:, 0].long()  # type: ignore

    def build_nnunet(self) -> torch.nn.Module:
        """Build the 2D residual ``UNetModel`` backbone.

        It uses six 3x3 kernel entries: the first with stride 1 and the remaining five
        with stride 2. It also uses instance normalization and a negative slope of 0.01.
        """
        kernels = [[3, 3]] * 6
        strides = [[1, 1]] + [[2, 2]] * 5
        model = UNetModel(in_channels=self._IN_CHANNELS,
                          n_class=self._OUT_CHANNELS,
                          kernels=kernels,
                          strides=strides,
                          dimension=2,
                          residual=True,
                          normalization_layer="instance",
                          negative_slope=0.01)
        return model