HFCrossEntropy
- class composer.metrics.HFCrossEntropy(dist_sync_on_step=False)
Hugging Face-compatible cross-entropy loss.
- Adds metric state variables:
sum_loss (float): The sum of the per-example loss in the batch.
total_batches (float): The number of batches to average across.
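As a rough illustration of how the two state variables combine into the reported value, here is a plain-Python sketch (no torch). It assumes both states start at zero and that each update contributes exactly one batch-level loss value; the class `LossState` and its methods are hypothetical names, not part of Composer.

```python
from dataclasses import dataclass


@dataclass
class LossState:
    # Hypothetical stand-ins for the metric's two state variables.
    sum_loss: float = 0.0       # running sum of batch losses
    total_batches: float = 0.0  # number of batches seen so far

    def add_batch(self, batch_loss: float) -> None:
        # Assumption: each batch adds its loss once and bumps the counter by one.
        self.sum_loss += batch_loss
        self.total_batches += 1

    def value(self) -> float:
        # The reported metric is the sum divided by the number of batches.
        return self.sum_loss / self.total_batches


state = LossState()
for loss in (2.0, 1.0, 0.5):
    state.add_batch(loss)
print(state.value())  # mean of the three batch losses
```

Because the divisor is a batch count, every batch carries equal weight in the average regardless of how many examples it contained.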
- Parameters
dist_sync_on_step (bool, optional): Synchronize metric state across processes at each forward() before returning the value at the step. Default: False.
- compute()
Aggregate the state over all processes to compute the metric.
- Returns
loss: The loss averaged across all batches, as a Tensor.
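The aggregation step can be pictured with the following sketch. It assumes that, during synchronization, both `sum_loss` and `total_batches` are summed across processes before the division happens; the function name and the per-rank numbers are hypothetical and chosen only to show why the order of operations matters.

```python
from typing import List, Tuple

# (sum_loss, total_batches) held by one process; values below are made up.
ProcessState = Tuple[float, int]


def compute_across_processes(states: List[ProcessState]) -> float:
    # Assumed reduction: sum each state over all processes first...
    total_loss = sum(s for s, _ in states)
    total_batches = sum(n for _, n in states)
    # ...then divide the global sum by the global batch count.
    return total_loss / total_batches


def naive_mean_of_means(states: List[ProcessState]) -> float:
    # For contrast: averaging each process's own mean weights ranks equally.
    return sum(s / n for s, n in states) / len(states)


rank_states = [(6.0, 3), (4.0, 1)]  # rank 0 saw 3 batches, rank 1 saw 1
print(compute_across_processes(rank_states))  # 10.0 / 4
print(naive_mean_of_means(rank_states))       # (2.0 + 4.0) / 2
```

Summing the raw states first keeps every batch equally weighted even when processes see different numbers of batches.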
- update(output, target)
Updates the internal state with results from a new batch.
- Parameters
output (Mapping): The output from the model. It must be either a Tensor of model logits or a Mapping that contains the loss or the model logits.
target (Tensor): A Tensor of ground-truth values to compare against.
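A plain-Python sketch of the dispatch that update() describes follows. It is not Composer's implementation: it assumes a Mapping exposes its values under the hypothetical keys `"loss"` and `"logits"`, represents logits as nested lists instead of Tensors, and uses a hand-written mean cross-entropy (`cross_entropy`, `CrossEntropySketch` are illustrative names).

```python
import math
from collections.abc import Mapping
from typing import Sequence, Union

Logits = Sequence[Sequence[float]]


def cross_entropy(logits: Logits, target: Sequence[int]) -> float:
    # Mean negative log-likelihood over the examples in the batch,
    # using the max-subtraction trick for a numerically stable log-sum-exp.
    total = 0.0
    for row, label in zip(logits, target):
        m = max(row)
        log_norm = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_norm - row[label]
    return total / len(target)


class CrossEntropySketch:
    def __init__(self) -> None:
        self.sum_loss = 0.0
        self.total_batches = 0

    def update(self, output: Union[Mapping, Logits], target: Sequence[int]) -> None:
        if isinstance(output, Mapping) and "loss" in output:
            # Assumed key: reuse the loss the model already computed.
            loss = float(output["loss"])
        else:
            # Otherwise recompute the loss from logits (Mapping or raw logits).
            logits = output["logits"] if isinstance(output, Mapping) else output
            loss = cross_entropy(logits, target)
        self.sum_loss += loss
        self.total_batches += 1

    def compute(self) -> float:
        return self.sum_loss / self.total_batches


metric = CrossEntropySketch()
metric.update({"loss": 0.7}, [0])             # loss taken directly
metric.update({"logits": [[0.0, 0.0]]}, [1])  # recomputed: ln 2
metric.update([[0.0, 0.0]], [0])              # raw logits: ln 2
print(metric.compute())
```

Preferring a precomputed loss avoids a second forward pass through the loss function; falling back to logits keeps the metric usable when the model returns only predictions.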