CrossEntropyLoss

class composer.models.loss.CrossEntropyLoss(dist_sync_on_step=False)

Bases: torchmetrics.metric.Metric

A TorchMetrics implementation of cross-entropy loss.

This class implements cross-entropy loss as a TorchMetrics Metric so that it can be returned by the metric() function of BaseMosaicModel.

compute() → composer.core.types.Tensor

Aggregate state over all processes and compute the metric.
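To see why state is aggregated across processes before the metric is computed, consider a hypothetical per-process state of (summed loss, batch count). The plain-Python sketch below mimics a sum reduction across ranks and contrasts it with averaging each rank's local result. The function names and the state layout are illustrative assumptions, not Composer internals.

```python
from typing import List, Tuple

# Hypothetical per-process state: (sum of per-batch mean losses, number of batches).
RankState = Tuple[float, int]


def sync_and_compute(states: List[RankState]) -> float:
    """Sum each state field across ranks (like a "sum" reduction), then divide once."""
    total_loss = sum(loss for loss, _ in states)
    total_batches = sum(count for _, count in states)
    return total_loss / total_batches


def naive_mean_of_rank_means(states: List[RankState]) -> float:
    """Average each rank's local result instead; skewed when batch counts differ."""
    return sum(loss / count for loss, count in states) / len(states)


# Rank 0 processed 3 batches, rank 1 processed 1 batch.
states = [(3.0, 3), (2.0, 1)]
print(sync_and_compute(states))          # 1.25
print(naive_mean_of_rank_means(states))  # 1.5
```

Reducing the raw totals first gives every batch equal weight regardless of which process saw it, which is why the states are synchronized before the final division.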

update(preds: composer.core.types.Tensor, target: composer.core.types.Tensor) → None

Update the state with new predictions and targets.