# Source code for composer.models.unet.model

# Copyright 2021 MosaicML. All Rights Reserved.

"""The Unet architecture used in image segmentation. The example we are using is for BRATS medical brain tumor dataset.

See the :doc:`Model Card </model_cards/unet>` for more details.
"""

import torch.nn as nn

from composer.models.unet._layers import ConvBlock, OutputBlock, ResidBlock, UpsampleBlock

__all__ = ["UNet"]


class UNet(nn.Module):
    """Unet Architecture adapted from NVidia `Deep Learning Examples`_.

    .. _Deep Learning Examples: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Segmentation/nnUNet/

    Args:
        in_channels (int): Number of input channels.
        n_class (int): Number of output layers.
        kernels (list): Conv layer kernel sizes.
        strides (list): Conv layer strides.
        normalization_layer (str): Normalization layer type, one of (``"batch"``, ``"instance"``).
        negative_slope (float): Leaky relu negative slope.
        residual (bool): Use residual connections.
        dimension (int): Filter dimensions.
    """

    def __init__(
        self,
        in_channels,
        n_class,
        kernels,
        strides,
        normalization_layer,
        negative_slope,
        residual,
        dimension,
    ):
        super(UNet, self).__init__()
        self.dim = dimension
        self.n_class = n_class
        self.residual = residual
        self.negative_slope = negative_slope
        # Builds a layer-type name such as "batchnorm2d" or "instancenorm3d";
        # presumably resolved to an nn layer inside the _layers helpers — confirm there.
        self.norm = normalization_layer + f"norm{dimension}d"
        # Channel count doubles per encoder level (starting at 2**5 == 32), capped at
        # 320 for 3-D models and 512 for 2-D models, matching the nnUNet reference.
        self.filters = [min(2**(5 + i), 320 if dimension == 3 else 512) for i in range(len(strides))]

        # Encoder blocks optionally use residual connections.
        down_block = ResidBlock if self.residual else ConvBlock
        self.input_block = self.get_conv_block(
            conv_block=down_block,
            in_channels=in_channels,
            out_channels=self.filters[0],
            kernel_size=kernels[0],
            stride=strides[0],
        )
        self.downsamples = self.get_module_list(
            conv_block=down_block,
            in_channels=self.filters[:-1],
            out_channels=self.filters[1:],
            kernels=kernels[1:-1],
            strides=strides[1:-1],
        )
        self.bottleneck = self.get_conv_block(
            conv_block=down_block,
            in_channels=self.filters[-2],
            out_channels=self.filters[-1],
            kernel_size=kernels[-1],
            stride=strides[-1],
        )
        # Decoder mirrors the encoder: filters/kernels/strides reversed.
        self.upsamples = self.get_module_list(
            conv_block=UpsampleBlock,
            in_channels=self.filters[1:][::-1],
            out_channels=self.filters[:-1][::-1],
            kernels=kernels[1:][::-1],
            strides=strides[1:][::-1],
        )
        self.output_block = self.get_output_block(decoder_level=0)
        self.apply(self.initialize_weights)
        self.n_layers = len(self.upsamples) - 1

    def forward(self, input_data):
        """Run the encoder-bottleneck-decoder pass.

        Encoder activations are stashed and fed to the matching decoder level
        as skip connections (``encoder_outputs[self.n_layers - idx]``).
        """
        out = self.input_block(input_data)
        encoder_outputs = [out]
        for downsample in self.downsamples:
            out = downsample(out)
            encoder_outputs.append(out)
        out = self.bottleneck(out)
        for idx, upsample in enumerate(self.upsamples):
            out = upsample(out, encoder_outputs[self.n_layers - idx])
        out = self.output_block(out)
        return out

    def get_conv_block(self, conv_block, in_channels, out_channels, kernel_size, stride):
        """Instantiate one ``conv_block`` with the model-wide dim/norm/slope settings."""
        return conv_block(
            dim=self.dim,
            stride=stride,
            norm=self.norm,
            kernel_size=kernel_size,
            in_channels=in_channels,
            out_channels=out_channels,
            negative_slope=self.negative_slope,
        )

    def get_output_block(self, decoder_level):
        """Build the final 1x1 projection from ``filters[decoder_level]`` to ``n_class``."""
        return OutputBlock(in_channels=self.filters[decoder_level], out_channels=self.n_class, dim=self.dim)

    def get_module_list(self, in_channels, out_channels, kernels, strides, conv_block):
        """Build an ``nn.ModuleList`` of conv blocks, one per zipped level spec."""
        layers = []
        for in_channel, out_channel, kernel, stride in zip(in_channels, out_channels, kernels, strides):
            conv_layer = self.get_conv_block(conv_block, in_channel, out_channel, kernel, stride)
            layers.append(conv_layer)
        return nn.ModuleList(layers)

    def initialize_weights(self, module):
        """Kaiming-initialize conv weights (applied to every submodule via ``self.apply``).

        Fix: previously only ``conv2d`` was matched, so 3-D models
        (``dimension == 3`` is explicitly supported above) were left with
        PyTorch's default init; ``conv3d`` is now matched as well, as in the
        NVIDIA nnUNet reference implementation.
        """
        name = module.__class__.__name__.lower()
        if name in ["conv2d", "conv3d"]:
            nn.init.kaiming_normal_(module.weight, a=self.negative_slope)