UNet

class composer.models.UNet(hparams: composer.models.unet.unet_hparams.UnetHparams)

Bases: composer.models.base.BaseMosaicModel

A U-Net model extending BaseMosaicModel.

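See the paper "U-Net: Convolutional Networks for Biomedical Image Segmentation" (Ronneberger, Fischer, and Brox, 2015) for details on the U-Net architecture.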
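As a worked example of the architecture's shape arithmetic, the sketch below traces one spatial dimension through the configuration described in the original U-Net paper: four down-sampling levels, two unpadded 3x3 convolutions per level, 2x2 max pooling with stride 2, and 2x2 up-convolutions, with skip-connection feature maps cropped before concatenation. It reproduces the paper's 572x572 input to 388x388 output. It is not a description of Composer's `UNet`, whose padding and depth may differ, and the helper names are hypothetical.

```python
from typing import List, Tuple


def two_unpadded_convs(size: int, kernel: int = 3) -> int:
    # Each unpadded k x k convolution trims (k - 1) pixels per spatial dimension.
    return size - 2 * (kernel - 1)


def unet_sizes(size: int, depth: int = 4) -> Tuple[int, List[int], List[int]]:
    """Trace one spatial dimension through a U-Net with unpadded convolutions."""
    skips: List[int] = []
    for _ in range(depth):                  # contracting path
        size = two_unpadded_convs(size)
        skips.append(size)                  # saved for the skip connection
        if size % 2:
            raise ValueError("2x2 max pooling needs an even size here")
        size //= 2                          # 2x2 max pooling, stride 2
    size = two_unpadded_convs(size)         # bottleneck
    crops: List[int] = []
    for skip in reversed(skips):            # expansive path
        size *= 2                           # 2x2 up-convolution doubles the size
        crops.append(skip - size)           # pixels cropped from the skip map
        size = two_unpadded_convs(size)     # convolutions after concatenation
    return size, skips, crops


out, skips, crops = unet_sizes(572)
# out == 388, skips == [568, 280, 136, 64], crops == [8, 32, 80, 176]
```

The crop list shows why the paper's skip connections need cropping: with unpadded convolutions, each encoder feature map is larger than the up-sampled decoder map it is concatenated with.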

Parameters

hparams (UnetHparams) – The hyperparameters for constructing the model.

forward(batch: Union[Tuple[Union[Tensor, Tuple[Tensor, ...], List[Tensor]], Union[Tensor, Tuple[Tensor, ...], List[Tensor]]], List[Tensor]]) → Tensor

Compute model output given an input.

Parameters

batch (Batch) – The input batch for the forward pass.

Returns

Tensors – The result that is passed to loss() as a Tensors object.
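The hand-off between `forward()` and `loss()` can be sketched with a hypothetical stand-in model. `ToyModel`, its clipping rule, the mean-squared-error loss, and the plain-list "tensors" are illustrative assumptions, not Composer code; the sketch only shows that `forward()` receives the whole `(input, target)` batch and that its return value is passed unchanged to `loss()` together with the same batch.

```python
from typing import List, Tuple

# Hypothetical stand-in for a Batch: (inputs, targets) as plain Python lists.
Batch = Tuple[List[float], List[int]]


class ToyModel:
    """Illustrative model that mimics only the forward()/loss() contract."""

    def forward(self, batch: Batch) -> List[float]:
        inputs, _ = batch  # forward() uses the input half of the batch
        # Clip each input into [0, 1] as a pseudo-probability.
        return [min(max(x, 0.0), 1.0) for x in inputs]

    def loss(self, outputs: List[float], batch: Batch) -> float:
        _, targets = batch  # loss() reads the targets from the same batch
        # Mean squared error between predictions and targets.
        return sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(targets)


model = ToyModel()
batch: Batch = ([0.2, 0.9, 1.5], [0, 1, 1])
outputs = model.forward(batch)
loss = model.loss(outputs, batch)
```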

inference2d(image)

Runs inference on a 3D image by passing each depth slice through the model.
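The slice-by-slice strategy can be sketched in plain Python. The nested-list layout `volume[depth][row][col]`, the `inference_by_slices` helper, and the thresholding `model_2d` callable are hypothetical; the actual method operates on tensors, and its batching and padding details are not documented here.

```python
from typing import Callable, List

Slice2D = List[List[float]]
Volume3D = List[Slice2D]  # indexed as volume[depth][row][col]


def inference_by_slices(volume: Volume3D,
                        model_2d: Callable[[Slice2D], Slice2D]) -> Volume3D:
    """Apply a 2D model to every depth slice and restack the results in order."""
    return [model_2d(depth_slice) for depth_slice in volume]


def threshold_model(slice_2d: Slice2D) -> Slice2D:
    # Hypothetical 2D "model": label a pixel 1.0 if its value exceeds 0.5.
    return [[1.0 if v > 0.5 else 0.0 for v in row] for row in slice_2d]


volume = [
    [[0.1, 0.7], [0.9, 0.2]],
    [[0.6, 0.4], [0.3, 0.8]],
]
mask = inference_by_slices(volume, threshold_model)
```

Because each slice is processed independently, the output keeps the input's depth ordering, and a 2D network can be reused on volumetric data without modification.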

loss(outputs: Any, batch: Union[Tuple[Union[Tensor, Tuple[Tensor, ...], List[Tensor]], Union[Tensor, Tuple[Tensor, ...], List[Tensor]]], List[Tensor]]) → Union[Tensor, Tuple[Tensor, ...], List[Tensor]]

Compute the loss of the model.

Parameters
  • outputs (Any) – The output of the forward pass.

  • batch (Batch) – The input batch from dataloader.

Returns

Tensors – The loss as a Tensors object.
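This page does not state which loss `UNet.loss()` computes. As a hedged illustration only, the soft Dice loss below is a common choice for segmentation models such as U-Net; the binary, flattened-list formulation, the `soft_dice_loss` name, and the `smooth` constant are assumptions, not Composer's implementation.

```python
from typing import Sequence


def soft_dice_loss(probs: Sequence[float], targets: Sequence[float],
                   smooth: float = 1.0) -> float:
    """Binary soft Dice loss over flattened probabilities and 0/1 targets."""
    if len(probs) != len(targets):
        raise ValueError("probs and targets must have the same length")
    intersection = sum(p * t for p, t in zip(probs, targets))
    denominator = sum(probs) + sum(targets)
    # `smooth` keeps the ratio defined when both masks are empty.
    dice = (2.0 * intersection + smooth) / (denominator + smooth)
    return 1.0 - dice


perfect = soft_dice_loss([1, 0, 1, 0], [1, 0, 1, 0])   # identical masks
disjoint = soft_dice_loss([1, 1, 0, 0], [0, 0, 1, 1])  # no overlap
```

Identical masks give a loss of 0; with no overlap the loss approaches 1, reduced slightly here by the smoothing term.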

metrics(train: bool = False) → Union[Metric, torchmetrics.collections.MetricCollection]

Get metrics for evaluating the model.

Warning

Each metric keeps internal state that is updated with the data seen so far. As a result, separate metric instances should be used for training and validation. See https://torchmetrics.readthedocs.io/en/latest/pages/overview.html for more details.

Parameters

train (bool, optional) – True to return metrics that should be computed during training and False otherwise. (default: False)

Returns

Metrics – A Metrics object.
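The warning about shared state can be illustrated with a small accumulator that follows the same `update()`/`compute()`/`reset()` pattern used by torchmetrics metrics. `RunningAccuracy` is a hypothetical class, not part of Composer or torchmetrics.

```python
from typing import Sequence


class RunningAccuracy:
    """Hypothetical metric whose state accumulates across update() calls."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: Sequence[int], targets: Sequence[int]) -> None:
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0

    def reset(self) -> None:
        self.correct = 0
        self.total = 0


# Separate instances keep training and validation statistics apart.
train_acc, val_acc = RunningAccuracy(), RunningAccuracy()
train_acc.update([1, 1, 0, 0], [1, 1, 0, 0])  # 4/4 correct
val_acc.update([1, 0], [0, 1])                # 0/2 correct

# A single shared instance blends both streams into one misleading number.
shared = RunningAccuracy()
shared.update([1, 1, 0, 0], [1, 1, 0, 0])
shared.update([1, 0], [0, 1])
```

Here training accuracy is 1.0 and validation accuracy is 0.0, while the shared instance reports 4/6, which describes neither split.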

validate(batch: Union[Tuple[Union[Tensor, Tuple[Tensor, ...], List[Tensor]], Union[Tensor, Tuple[Tensor, ...], List[Tensor]]], List[Tensor]]) → Tuple[Any, Any]

Compute model outputs on provided data.

The output of this function will be directly used as input to all metrics returned by metrics().

Parameters

batch (Batch) – The data to validate on, given as a tuple of tensors (input, target).

Returns

Tuple[Any, Any] – Tuple that is passed directly to the update() methods of the metrics returned by metrics(). Most often, this will be a tuple of the form (predictions, targets).
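The hand-off from `validate()` to the metrics can be sketched as follows. `ToySegmenter`, `PixelAccuracy`, and the list-based batches are hypothetical stand-ins; only the pattern of unpacking the returned `(predictions, targets)` tuple into a metric's `update()` mirrors the contract described above.

```python
from typing import List, Tuple

Batch = Tuple[List[float], List[int]]


class PixelAccuracy:
    """Hypothetical metric with an update()/compute() interface."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: List[int], targets: List[int]) -> None:
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self) -> float:
        return self.correct / self.total


class ToySegmenter:
    """Hypothetical model whose validate() returns (predictions, targets)."""

    def validate(self, batch: Batch) -> Tuple[List[int], List[int]]:
        inputs, targets = batch
        preds = [1 if x > 0.5 else 0 for x in inputs]  # hard per-pixel labels
        return preds, targets


model = ToySegmenter()
metric = PixelAccuracy()
for batch in [([0.9, 0.2, 0.7], [1, 0, 0]), ([0.1, 0.6], [0, 1])]:
    # The returned tuple is unpacked straight into the metric's update().
    metric.update(*model.validate(batch))
```

After both batches the metric has seen 4 correct labels out of 5, so `metric.compute()` returns 0.8.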