UNet
- class composer.models.UNet(hparams: composer.models.unet.unet_hparams.UnetHparams)
Bases: composer.models.base.BaseMosaicModel
A U-Net model extending MosaicClassifier. See the original U-Net paper for details on the architecture.
- Parameters
hparams (UnetHparams) – The hyperparameters for constructing the model.
- forward(batch: Union[Tuple[Union[Tensor, Tuple[Tensor, ...], List[Tensor]], Union[Tensor, Tuple[Tensor, ...], List[Tensor]]], List[Tensor]]) → Tensor
Compute model output given an input.
- inference2d(image)
Runs inference on a 3D image by passing each depth slice through the model.
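The slice-by-slice strategy described for inference2d can be sketched as follows. This is not Composer's implementation: the (N, C, D, H, W) input layout, the function name `inference2d_sketch`, and the 1x1-convolution stand-in for the U-Net are all assumptions made for illustration.

```python
import torch
import torch.nn as nn


def inference2d_sketch(model: nn.Module, image: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. Assumed layout: image is (N, C, D, H, W) and
    # model maps a 2D batch (N, C, H, W) -> (N, K, H, W).
    outputs = []
    with torch.no_grad():
        for i in range(image.shape[2]):
            # Select depth slice i, giving a 2D batch of shape (N, C, H, W).
            outputs.append(model(image[:, :, i]))
    # Re-assemble the per-slice predictions along the depth axis: (N, K, D, H, W).
    return torch.stack(outputs, dim=2)


# Toy stand-in for the U-Net: a 1x1 conv from 4 input channels to 3 classes.
toy_model = nn.Conv2d(in_channels=4, out_channels=3, kernel_size=1)
volume = torch.randn(2, 4, 5, 8, 8)
pred = inference2d_sketch(toy_model, volume)
print(pred.shape)  # torch.Size([2, 3, 5, 8, 8])
```

Looping over slices keeps peak memory at one slice per step; folding all slices into the batch dimension would be a faster alternative when memory allows.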
- loss(outputs: Any, batch: Union[Tuple[Union[Tensor, Tuple[Tensor, ...], List[Tensor]], Union[Tensor, Tuple[Tensor, ...], List[Tensor]]], List[Tensor]]) → Union[Tensor, Tuple[Tensor, ...], List[Tensor]]
Compute the loss of the model.
- Parameters
outputs (Any) – The output of the forward pass.
batch (Batch) – The input batch from dataloader.
- Returns
Tensors – The loss as a Tensors object.
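A minimal sketch of the loss(outputs, batch) contract, assuming the batch is an (inputs, targets) pair with one integer class label per pixel. The name `loss_sketch` is hypothetical, and cross-entropy is a stand-in criterion chosen for illustration; it is not necessarily the loss the Composer UNet uses.

```python
import math

import torch
import torch.nn.functional as F


def loss_sketch(outputs: torch.Tensor, batch) -> torch.Tensor:
    # Assumption: batch is an (inputs, targets) pair; only targets are needed here.
    _, targets = batch
    # Stand-in criterion: per-pixel cross-entropy averaged over all pixels.
    return F.cross_entropy(outputs, targets)


# Logits for 3 classes over a 4x4 image, batch of 2; one integer label per pixel.
logits = torch.zeros(2, 3, 4, 4)
labels = torch.zeros(2, 4, 4, dtype=torch.long)
loss = loss_sketch(logits, (torch.randn(2, 1, 4, 4), labels))
# Uniform logits give a per-pixel loss of log(3).
print(math.isclose(loss.item(), math.log(3), rel_tol=1e-6))  # True
```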
- metrics(train: bool = False) → Union[Metric, torchmetrics.collections.MetricCollection]
Get metrics for evaluating the model.
Warning
Each metric keeps state that is updated with the data seen so far. As a result, separate metric instances should be used for training and validation. See https://torchmetrics.readthedocs.io/en/latest/pages/overview.html for more details.
- Parameters
train (bool, optional) – True to return metrics that should be computed during training and False otherwise. (default: False)
- Returns
Metrics – A Metrics object.
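To illustrate the warning above, here is a minimal sketch using a hypothetical `RunningAccuracy` class that mimics the update/compute/reset pattern of torchmetrics metrics (it is not a torchmetrics class). Sharing one instance between training and validation mixes both data streams into a single accumulated state.

```python
class RunningAccuracy:
    """Hypothetical stand-in for a stateful metric (update / compute / reset)."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds, targets) -> None:
        # State accumulates across calls until reset() is invoked.
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self) -> float:
        return self.correct / self.total

    def reset(self) -> None:
        self.correct = 0
        self.total = 0


# One shared instance: the training batch leaks into the validation score.
shared = RunningAccuracy()
shared.update([1, 1, 1, 1], [1, 1, 1, 1])  # training batch: 4 of 4 correct
shared.update([0, 0], [1, 1])              # validation batch: 0 of 2 correct
print(shared.compute())  # 4 correct out of 6 seen, so validation looks inflated

# Separate instances keep the two streams apart.
train_acc, val_acc = RunningAccuracy(), RunningAccuracy()
train_acc.update([1, 1, 1, 1], [1, 1, 1, 1])
val_acc.update([0, 0], [1, 1])
print(train_acc.compute(), val_acc.compute())  # 1.0 0.0
```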
- validate(batch: Union[Tuple[Union[Tensor, Tuple[Tensor, ...], List[Tensor]], Union[Tensor, Tuple[Tensor, ...], List[Tensor]]], List[Tensor]]) → Tuple[Any, Any]
Compute model outputs on the provided data.
The output of this function is passed directly as input to all metrics returned by metrics().
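A hypothetical evaluation loop showing how the (outputs, targets) pair from validate() might feed a metric, assuming it is unpacked into `metric.update(outputs, targets)`. `ToyModel`, `MatchCount`, and `evaluate` are illustrative names, not Composer or torchmetrics APIs, and plain Python lists stand in for tensors.

```python
from typing import Any, List, Tuple


class ToyModel:
    """Hypothetical model exposing a validate() method returning (outputs, targets)."""

    def validate(self, batch: Tuple[List[int], List[int]]) -> Tuple[Any, Any]:
        inputs, targets = batch
        # A trivial "prediction": 1 if the input is positive, else 0.
        preds = [int(x > 0) for x in inputs]
        return preds, targets


class MatchCount:
    """Hypothetical stateful metric: fraction of predictions equal to the target."""

    def __init__(self) -> None:
        self.matches = 0
        self.total = 0

    def update(self, preds, targets) -> None:
        self.matches += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self) -> float:
        return self.matches / self.total


def evaluate(model, dataloader, metric) -> float:
    for batch in dataloader:
        # The pair returned by validate() goes straight into the metric.
        outputs, targets = model.validate(batch)
        metric.update(outputs, targets)
    return metric.compute()


loader = [([-1, 2, 3], [0, 1, 1]), ([5, -4], [1, 1])]
print(evaluate(ToyModel(), loader, MatchCount()))  # 0.8
```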