# Source code for composer.metrics.nlp

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""A collection of common torchmetrics for NLP tasks."""

import copy
import functools
import logging
import os
import re
import string
import warnings
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor
from torch.nn import functional as F
from torchmetrics import Metric

from composer.utils import dist
from composer.utils.eval_client import EvalClient, LambdaEvalClient, LocalEvalClient, MosaicMLLambdaEvalClient
from composer.utils.warnings import VersionedDeprecationWarning

log = logging.getLogger(__name__)

__all__ = [
    'InContextLearningMetric',
    'InContextLearningLMAccuracy',
    'InContextLearningMultipleChoiceAccuracy',
    'InContextLearningQAAccuracy',
    'InContextLearningCodeEvalAccuracy',
    'BinaryF1Score',
    'LanguageCrossEntropy',
    'MaskedAccuracy',
    'LanguagePerplexity',
    'InContextLearningLMExpectedCalibrationError',
    'InContextLearningMCExpectedCalibrationError',
]


class MaskedAccuracy(Metric):
    """Accuracy over class predictions that skips positions labelled ``ignore_index``.

    Adds metric state variables:
        correct (float): The number of non-ignored positions where the predicted class
            matched the target.
        total (float): The number of non-ignored positions that were scored.

    Args:
        ignore_index (int): The class index to ignore. Default: -100.
        dist_sync_on_step (bool, optional): Synchronize metric state across processes at
            each forward() before returning the value at the step. Default: ``False``.
    """

    # Make torchmetrics call update only once
    full_state_update = False

    def __init__(self, ignore_index: int = -100, dist_sync_on_step: bool = False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.ignore_index = ignore_index

        # Both counters are summed when state is gathered from multiple processes
        for state_name in ('correct', 'total'):
            self.add_state(state_name, default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate the hit and scored-position counts for one batch.

        Args:
            preds (torch.Tensor): Per-class scores; the class axis is the last one.
            target (torch.Tensor): Ground-truth class indices, with the same shape as
                ``preds`` minus its last axis.
        """
        # Collapse the class axis to the index of the highest score
        predicted_classes = preds.argmax(dim=-1)
        assert predicted_classes.shape == target.shape

        # Only positions whose target is not the ignore index take part
        keep = target != self.ignore_index
        matches = predicted_classes[keep] == target[keep]

        self.correct += matches.sum()
        self.total += keep.sum()

    def compute(self):
        """Return ``correct / total`` as a floating point tensor."""
        correct, total = self.correct, self.total
        assert isinstance(correct, Tensor)
        assert isinstance(total, Tensor)
        return correct.float() / total
class LanguageCrossEntropy(Metric):
    """Torchmetric that computes cross entropy on language modeling outputs.

    Adds metric state variables:
        sum_loss (float): The summed (``reduction='sum'``) cross entropy over every
            non-ignored target seen so far.
        total_items (float): The number of targets that were not equal to
            ``ignore_index``; this is the denominator of the average.

    Args:
        dist_sync_on_step (bool, optional): Synchronize metric state across processes at
            each forward() before returning the value at the step. Default: ``False``.
        ignore_index (int, optional): The class index to ignore. Default: ``-100``.
    """

    # Make torchmetrics call update only once
    full_state_update = False

    def __init__(self, dist_sync_on_step: bool = False, ignore_index: int = -100):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.ignore_index = ignore_index
        # Summed (not averaged) so that the division by the token count happens once, in compute()
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
        self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('total_items', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, output: Union[Mapping, Tensor], target: Tensor) -> None:
        """Updates the internal state with results from a new batch.

        Args:
            output (Mapping | ~torch.Tensor): Either the logits tensor itself, or a Mapping
                holding the logits under the ``'logits'`` key.
            target (~torch.Tensor): A Tensor of ground-truth values to compare against.

        Raises:
            TypeError: If ``output`` is neither a Mapping nor a Tensor.
        """
        if isinstance(output, Mapping):
            logits = output['logits']
        elif isinstance(output, Tensor):
            logits = output
        else:
            raise TypeError(f'Type {type(output)} for the output is unsupported.')

        # reshape (rather than view) so non-contiguous inputs, e.g. sliced labels, are accepted
        target = target.reshape(-1)
        logits = logits.reshape(target.shape[0], -1)
        losses = self.loss_fn(logits, target)

        total_items = (target != self.ignore_index).sum()
        self.total_items += total_items  #type: ignore (third-party)

        # accumulate loss over all batches
        self.sum_loss += losses

    def compute(self) -> Tensor:
        """Aggregate the state over all processes to compute the metric.

        Returns:
            loss: ``sum_loss / total_items``, i.e. the loss averaged over every non-ignored
                target, as a :class:`~torch.Tensor`.
        """
        # Return average loss over entire dataset
        return self.sum_loss / self.total_items  #type: ignore (third-party)
class BinaryF1Score(Metric):
    """Implements F1 Scores for binary classification tasks.

    The score is computed as ``TP / (TP + 0.5 * (FP + FN))``.

    Adds metric state variables:
        true_positive (float): A counter of how many items were correctly classified as positives.
        false_positive (float): A counter of how many items were incorrectly classified as positives.
        false_negative (float): A counter of how many items were incorrectly classified as negatives.

    Args:
        dist_sync_on_step (bool, optional): Synchronize metric state across processes at
            each forward() before returning the value at the step. Default: ``False``.
    """

    # Make torchmetrics call update only once
    full_state_update = False

    def __init__(self, dist_sync_on_step: bool = False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.add_state('true_positive', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('false_positive', default=torch.tensor(0), dist_reduce_fx='sum')
        self.add_state('false_negative', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, output: Tensor, target: Tensor) -> None:
        """Updates the internal state with results from a new batch.

        Args:
            output (~torch.Tensor): Per-class scores; the predicted class is the argmax over ``dim=1``.
            target (~torch.Tensor): A Tensor of ground-truth values (``0`` or ``1``) to compare against.
        """
        predictions = torch.argmax(output, dim=1)
        actual_positive = target == 1
        actual_negative = target == 0

        # Predictions are class indices, so summing them over the positives counts the predicted 1s
        self.true_positive += predictions[actual_positive].sum()
        # A positive that was predicted as 0 is a false *negative* ...
        self.false_negative += (predictions[actual_positive] == 0).sum()
        # ... and a negative that was predicted as 1 is a false *positive*
        self.false_positive += (predictions[actual_negative] == 1).sum()

    def compute(self) -> Tensor:
        """Aggregate the state over all processes to compute the metric.

        Returns:
            f1: The F1 score ``TP / (TP + 0.5 * (FP + FN))`` as a :class:`~torch.Tensor`.
        """
        assert isinstance(self.true_positive, Tensor)
        assert isinstance(self.false_positive, Tensor)
        assert isinstance(self.false_negative, Tensor)
        f1 = (self.true_positive) / (self.true_positive + (0.5 * (self.false_negative + self.false_positive)))
        return f1
class LanguagePerplexity(LanguageCrossEntropy):
    """Perplexity metric built on top of :class:`~composer.metrics.nlp.LanguageCrossEntropy`."""

    def compute(self) -> Tensor:
        """Exponentiate the average cross entropy computed by the parent class."""
        mean_cross_entropy = super().compute()
        return mean_cross_entropy.exp()
class InContextLearningMetric(Metric):
    """Base class for In-context learning (ICL) metrics.

    Constructing this class (or any subclass) emits a ``VersionedDeprecationWarning``.
    Unlike a plain torchmetrics ``Metric``, the wrapped ``update()`` returns whatever the
    subclass's ``update()`` returns.
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            VersionedDeprecationWarning(
                '`InContextLearningMetric` and it\'s subclasses have been deprecated and ' +
                'migrated to MosaicML\'s llm-foundry repo under the llmfoundry.eval.datasets.in_context_learning module: '
                + 'https://github.com/mosaicml/llm-foundry/blob/main/scripts/eval/README.md',
                remove_version='0.23.0',
            ),
        )
        super().__init__(*args, **kwargs)
        self.needs_batch = True

    def _wrap_update(self, update: Callable) -> Callable:
        """Overwrite default _wrap_update to return result of update().

        Torch metrics wraps update with following wrapped_func but explicitly does not return the value.
        In general, torchmetrics update() does not return a value, but we want to in order to pass it on
        to state.metric_outputs.

        Args:
            update (Callable): The metric's ``update`` method.

        Returns:
            Callable: A wrapper that performs the bookkeeping below and returns the result of ``update``.
        """

        @functools.wraps(update)
        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
            # Invalidate any cached compute() result and count this call
            self._computed = None
            self._update_count += 1
            with torch.set_grad_enabled(self._enable_grad):
                try:
                    update_result = update(*args, **kwargs)
                except RuntimeError as err:
                    # Replace the device-mismatch error with an actionable hint
                    if 'Expected all tensors to be on' in str(err):
                        raise RuntimeError(
                            'Encountered different devices in metric calculation (see stacktrace for details).'
                            ' This could be due to the metric class not being on the same device as input.'
                            f' Instead of `metric={self.__class__.__name__}(...)` try to do'
                            f' `metric={self.__class__.__name__}(...).to(device)` where'
                            ' device corresponds to the device of the input.',
                        ) from err
                    raise

            if self.compute_on_cpu:
                self._move_list_states_to_cpu()
            return update_result

        return wrapped_func

    def update(
        self,
        batch: dict,
        output_logits: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        outputs: Optional[torch.Tensor] = None,
    ):
        """Abstract interface for computing an in-context learning metrics.

        The `output_logits` argument is deprecated and will be removed in v0.23.0 while its functionality
        will be moved to `outputs`.

        Args:
            batch (dict): Batch must consist minimally of `input_ids` as well as any other structure needed
                to compute the metric.
            output_logits (torch.Tensor): The model outputs evaluated on the batch `input_ids`
            labels (torch.Tensor): The correct outputs.
            outputs (torch.Tensor): The model outputs evaluated on the batch `input_ids`.

        Raises:
            NotImplementedError: Abstract method must be implemented by subclasses
        """
        raise NotImplementedError

    @staticmethod
    def rename_args(
        batch: dict,
        output_logits: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        outputs: Optional[torch.Tensor] = None,
    ) -> Tuple[dict, torch.Tensor, torch.Tensor]:
        """Resolve the deprecated ``output_logits`` argument into ``outputs``.

        Args:
            batch (dict): Passed through unchanged.
            output_logits (torch.Tensor, optional): Deprecated alias of ``outputs``; using it emits a
                ``VersionedDeprecationWarning``.
            labels (torch.Tensor, optional): The correct outputs. Must not be ``None``.
            outputs (torch.Tensor, optional): The model outputs.

        Returns:
            Tuple[dict, torch.Tensor, torch.Tensor]: ``(batch, outputs, labels)``. Note the order differs
            from the parameter order.

        Raises:
            ValueError: If both ``outputs`` and ``output_logits`` are given, or if ``labels`` or the
                resolved ``outputs`` is ``None``.
        """
        if outputs is not None and output_logits is not None:
            raise ValueError('Cannot use both `outputs` and `output_logits`')
        if output_logits is not None:
            warnings.warn(
                VersionedDeprecationWarning('`output_logits` has been renamed to `outputs`.', remove_version='0.23.0'),
            )
            outputs = output_logits

        if labels is None:
            raise ValueError('`labels` cannot be None')
        if outputs is None:
            raise ValueError('`outputs` cannot be None')

        return batch, outputs, labels
class InContextLearningQAAccuracy(InContextLearningMetric):
    r"""Computes accuracy for In-context learning (ICL) question answering (QA) tasks.

    ICL QA tasks consist of some number of example question answering tasks (referred to as the 'context'),
    followed by a test task where the model must match one of the possible answer aliases (referred to as
    the 'continuation').

    For example, the model may be provided the context below and evaluated on its ability to correctly
    predict the continuation.

    Context: `Question: Who was president of the United States in 2012?\nAnswer: Barack Obama\nQuestion: Is water wet?\nAnswer: `
    Continuation: [`yes`, `no`]

    Both predictions and answers will be normalized before comparison.

    Adds metric state variables:
        correct (float): The number of instances where the (cleaned) prediction started with any of the
            (cleaned) answer aliases.
        total (float): The number of total instances that were predicted.

    Args:
        dist_sync_on_step (bool, optional): Synchronize metric state across processes at
            each forward() before returning the value at the step. Default: ``False``.
    """

    # Make torchmetrics call update only once
    full_state_update = False

    def __init__(self, dist_sync_on_step: bool = False):
        # state from multiple processes
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state('correct', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum')
        # Empty template; update() deep-copies it so each call returns fresh lists
        self.metric_result_dict = {
            'cleaned_output': [],
            'original_label': [],
            'cleaned_label': [],
            'result': [],
        }

    def normalize_answer(self, answer: str):
        """Lower text and remove punctuation, articles and extra whitespace.

        Copied from https://github.com/mandarjoshi90/triviaqa/blob/master/evaluation/triviaqa_evaluation.py

        Args:
            answer (str): The raw answer text.

        Returns:
            str: The normalized answer.
        """

        def remove_articles(text: str) -> str:
            return re.sub(r'\b(a|an|the)\b', ' ', text)

        def white_space_fix(text: str) -> str:
            return ' '.join(text.split())

        def handle_punc(text: str) -> str:
            # ASCII punctuation plus left/right single quotation marks, acute accent and backtick.
            # Written as escapes so the set survives re-encoding of this file.
            exclude = set(string.punctuation + '\u2018\u2019\u00b4`')
            return ''.join(ch if ch not in exclude else ' ' for ch in text)

        def lower(text: str) -> str:
            return text.lower()

        def replace_underscore(text: str) -> str:
            return text.replace('_', ' ')

        return white_space_fix(remove_articles(handle_punc(lower(replace_underscore(answer))))).strip()

    def update(self, outputs: List[str], labels: List[List[str]], batch: Dict[str, Any]):
        """Score a batch of generations against their answer aliases.

        Args:
            outputs (List[str]): One generated string per sample.
            labels (List[List[str]]): The accepted answer aliases for each sample.
            batch (Dict[str, Any]): May carry the optional keys ``cot_delimiter`` (default ``''``),
                ``do_normalization`` (default ``True``) and ``stopping_criteria`` (default ``None``).

        Returns:
            dict: Per-sample lists under ``cleaned_output``, ``original_label``, ``cleaned_label`` and
            ``result`` (``1`` for a hit, ``0`` otherwise).
        """
        cot_delimiter = batch.get('cot_delimiter', '')
        do_normalization = batch.get('do_normalization', True)
        stopping_criteria = batch.get('stopping_criteria', None)
        metric_result_dict = copy.deepcopy(self.metric_result_dict)
        for sample_output, sample_labels in zip(outputs, labels):
            final_answer = sample_output

            if stopping_criteria is not None and len(stopping_criteria) > 0:
                # Keep only the text before the first stopping sequence.
                # NOTE(review): the criteria are joined into a regex unescaped, so metacharacters in a
                # stop sequence are interpreted as regex syntax -- confirm this is intended.
                final_answer = re.split('|'.join(stopping_criteria), final_answer)[0]

            if cot_delimiter is not None and len(cot_delimiter) > 0:
                # The answer is whatever follows the last chain-of-thought delimiter
                final_answer = final_answer.split(cot_delimiter)[-1]

            if do_normalization:
                cleaned_final_answer = self.normalize_answer(final_answer)
                cleaned_sample_labels = {self.normalize_answer(label) for label in sample_labels}
            else:
                # even if normalization is off, we should still strip leading/trailing whitespaces
                cleaned_final_answer = final_answer.strip()
                cleaned_sample_labels = {sample_label.strip() for sample_label in sample_labels}

            metric_result_dict['original_label'].append(sample_labels)
            metric_result_dict['cleaned_output'].append(cleaned_final_answer)
            metric_result_dict['cleaned_label'].append(cleaned_sample_labels)

            if any(cleaned_final_answer.startswith(label) for label in cleaned_sample_labels):
                self.correct += torch.tensor(1.0)
                metric_result_dict['result'].append(1)
            else:
                metric_result_dict['result'].append(0)

            self.total += torch.tensor(1.0)

        return metric_result_dict

    def compute(self):
        """Return ``correct / total``."""
        assert isinstance(self.correct, Tensor)
        assert isinstance(self.total, Tensor)
        return self.correct / self.total
class InContextLearningLMAccuracy(InContextLearningMetric):
    r"""Exact-match accuracy for In-context learning (ICL) language modeling (LM) tasks.

    An ICL LM sample is a 'context' made of solved examples followed by one unsolved example; the tokens
    the model has to produce for that last example are the 'continuation'. A sample counts as correct
    only when the argmax prediction equals the label at every continuation position. Predictions on the
    context tokens are not looked at.

    Context: `The dog is->fuzzy\nthe water is->hot\nthe tree is->`
    Continuation: `green`

    Adds metric state variables:
        correct (float): The number of samples whose whole continuation was predicted correctly.
        total (float): The number of samples that were scored.

    Args:
        dist_sync_on_step (bool, optional): Synchronize metric state across processes at
            each forward() before returning the value at the step. Default: ``False``.
    """

    # Make torchmetrics call update only once
    full_state_update = False

    def __init__(self, dist_sync_on_step: bool = False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # Counters are summed when state is gathered from multiple processes
        self.add_state('correct', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum')
        # Empty template that update() deep-copies on every call
        self.metric_result_dict = {'context': [], 'label': [], 'output': [], 'result': []}

    def update(
        self,
        batch: dict,
        output_logits: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        outputs: Optional[torch.Tensor] = None,
    ):
        """Score every sample of the batch and return the per-sample details.

        Reads ``batch['continuation_indices']`` and ``batch['input_ids']``. ``output_logits`` is the
        deprecated alias of ``outputs``.

        Returns:
            dict: Per-sample lists under ``context``, ``label``, ``output`` and ``result``.
        """
        batch, outputs, labels = InContextLearningMetric.rename_args(
            batch=batch,
            output_logits=output_logits,
            labels=labels,
            outputs=outputs,
        )

        per_sample = copy.deepcopy(self.metric_result_dict)
        for row, cont_idx in enumerate(batch['continuation_indices']):
            # Both logits and labels are read one position before each continuation index
            shifted_idx = cont_idx - 1
            predicted_tokens = outputs[row].index_select(dim=0, index=shifted_idx).argmax(dim=-1)
            target_tokens = labels[row].index_select(dim=0, index=shifted_idx)

            per_sample['context'].append(batch['input_ids'][row][:cont_idx[0]])
            per_sample['label'].append(target_tokens)
            per_sample['output'].append(predicted_tokens)

            is_exact_match = (predicted_tokens == target_tokens).all().int()
            self.correct += is_exact_match
            per_sample['result'].append(int(is_exact_match))
            self.total += torch.tensor(1.0)

        return per_sample

    def compute(self):
        """Return ``correct / total``."""
        correct, total = self.correct, self.total
        assert isinstance(correct, Tensor)
        assert isinstance(total, Tensor)
        return correct / total
class InContextLearningMultipleChoiceAccuracy(InContextLearningMetric):
    r"""Computes accuracy for In-context learning (ICL) multiple choice (MC) tasks.

    ICL MC tasks consists of a series of questions with some number of possible choices (only one of which
    can be correct). At inference time each possible choice is given to the model as a separate input and
    the one for which the model assigns the lowest perplexity to the choice is considered the model's
    choice. The model is correct if it "chooses" the right answer.

    Context: `The dog is->fuzzy\nthe water is->hot\nthe tree is->`
    Continuation: `green`

    Adds metric state variables:
        correct (float): The number of questions where the lowest-perplexity choice was the gold choice.
        total (float): The number of total questions that were scored.

    Args:
        dist_sync_on_step (bool, optional): Synchronize metric state across processes at
            each forward() before returning the value at the step. Default: ``False``.
    """

    # Make torchmetrics call update only once
    full_state_update = False

    def __init__(self, dist_sync_on_step: bool = False):
        # state from multiple processes
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state('correct', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('total', default=torch.tensor(0.0), dist_reduce_fx='sum')
        # Empty template; update() deep-copies it so each call returns fresh lists
        self.metric_result_dict = {
            'context': [],
            'correct_choice': [],
            'correct_choice_idx': [],
            'selected_choice': [],
            'selected_choice_idx': [],
            'all_choices': [],
            'result': [],
        }

    def update(
        self,
        batch: dict,
        output_logits: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        outputs: Optional[torch.Tensor] = None,
    ):
        """Score every question in the batch and return the per-question details.

        Reads ``continuation_indices``, ``choice_groupings``, ``gold_indices`` and ``input_ids`` from
        ``batch``, plus ``attention_mask`` when present. ``output_logits`` is the deprecated alias of
        ``outputs``.

        Returns:
            dict: Per-question lists under ``context``, ``correct_choice``, ``correct_choice_idx``,
            ``selected_choice``, ``selected_choice_idx`` and ``result``; ``all_choices`` is included only
            when it was filled.
        """
        batch, outputs, labels = InContextLearningMetric.rename_args(
            batch=batch,
            output_logits=output_logits,
            labels=labels,
            outputs=outputs,
        )

        # One perplexity per batch row, i.e. per (question, choice) pair
        perplexities = []
        for batch_idx, cont_idx in enumerate(batch['continuation_indices']):
            # continuation indices refer to indices in the original input's token space
            cont_tok_logits = outputs[batch_idx].index_select(dim=0, index=cont_idx - 1)
            # labels have been shifted left by one index, so the cont_idx needs to be shifted as well.
            cont_tok_targ = labels[batch_idx].index_select(dim=0, index=cont_idx - 1)
            cross_entropy = F.cross_entropy(cont_tok_logits, cont_tok_targ)
            perplexity = torch.exp(cross_entropy)
            perplexities.append(perplexity)

        metric_result_dict = copy.deepcopy(self.metric_result_dict)
        for (start, end), gold_idx in zip(batch['choice_groupings'], batch['gold_indices']):
            # Rows start..end-1 are the choices of one question; pick the lowest perplexity
            subset = perplexities[start:end]
            idx_min = subset.index(min(subset))

            if idx_min == gold_idx:
                self.correct += torch.tensor(1.0)
                metric_result_dict['result'].append(1)
            else:
                metric_result_dict['result'].append(0)

            group_input_ids = batch['input_ids'][start:end]
            group_cont_indices = batch['continuation_indices'][start:end]

            question = batch['input_ids'][start][:batch['continuation_indices'][start][0]]

            gold_span = group_cont_indices[gold_idx]
            correct_choice = group_input_ids[gold_idx][gold_span[0]:gold_span[-1] + 1]
            selected_span = group_cont_indices[idx_min]
            selected_choice = group_input_ids[idx_min][selected_span[0]:selected_span[-1] + 1]

            metric_result_dict['context'].append(question)
            metric_result_dict['correct_choice'].append(correct_choice)
            metric_result_dict['correct_choice_idx'].append(gold_idx)
            metric_result_dict['selected_choice'].append(selected_choice)
            metric_result_dict['selected_choice_idx'].append(idx_min)

            all_choices = group_input_ids
            # Unpads the choices. Necessary in case different choices have different token lengths.
            if 'attention_mask' in batch:
                # The mask rows must come from the same start:end window as the choices; indexing the
                # mask with the group-local position would read the rows of the first question.
                # assumes each attention_mask row is a boolean mask -- TODO confirm
                group_masks = batch['attention_mask'][start:end]
                all_choices_list = [choice[mask] for choice, mask in zip(all_choices, group_masks)]
                metric_result_dict['all_choices'].append(all_choices_list)

            self.total += torch.tensor(1.0)

        # Don't return all_choices if we didn't fill it up (i.e. didn't use causal lms)
        if metric_result_dict['all_choices'] == []:
            metric_result_dict.pop('all_choices')

        return metric_result_dict

    def compute(self):
        """Return ``correct / total``."""
        assert isinstance(self.correct, Tensor)
        assert isinstance(self.total, Tensor)
        return self.correct.float() / self.total
class InContextLearningExpectedCalibrationError(InContextLearningMetric):
    """Generic class for Expected Calibration Error (ECE) (cite: https://arxiv.org/pdf/1706.04599.pdf).

    Expected calibration error is calculated by dividing predictions into buckets based on the model's
    confidence (a probability value between 0 and 1). We then calculate the accuracy within each bucket
    and calculate the average gap between confidence and accuracy across buckets, weighted by the number
    of samples in each bucket.

    Each task must implement its own definition of "confidence" to be computed via the `update` method.

    Adds metric state variables:
        bucket_totals (float): The number of total instances that were predicted per bucket.
        bucket_correct (float): The number of instances where the prediction matched the target per bucket.

    Args:
        dist_sync_on_step (bool, optional): Synchronize metric state across processes at
            each forward() before returning the value at the step. Default: ``False``.
        n_buckets (int): Number of distinct buckets to split the confidence distribution into

    Raises:
        ValueError: If ``n_buckets`` is smaller than 1.
    """

    def __init__(self, dist_sync_on_step: bool = False, n_buckets: int = 10):
        # state from multiple processes
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        if n_buckets < 1:
            raise ValueError(f'`n_buckets` must be at least 1, but got {n_buckets}.')
        self.n_buckets = n_buckets
        self.add_state('bucket_totals', default=torch.zeros(n_buckets), dist_reduce_fx='sum')
        self.add_state('bucket_correct', default=torch.zeros(n_buckets), dist_reduce_fx='sum')

    def update(self, batch: dict, output_logits: torch.Tensor, labels: torch.Tensor):
        """No-op placeholder; task-specific subclasses fill the buckets."""
        pass

    def compute(self):
        """Return the bucket-weighted average gap between accuracy and confidence.

        Each bucket's confidence is taken to be the midpoint of its interval. Empty buckets are skipped,
        so the result is ``0.0`` when nothing has been recorded.
        """
        assert isinstance(self.bucket_correct, Tensor)
        assert isinstance(self.bucket_totals, Tensor)

        result = torch.tensor(0.0, device=self.bucket_correct.device)
        total_obs = torch.sum(self.bucket_totals)
        for i in range(self.n_buckets):
            if self.bucket_totals[i] == 0:
                continue

            acc_bucket_i = self.bucket_correct[i] / self.bucket_totals[i]
            upper_bound = (i + 1) / self.n_buckets
            lower_bound = i / self.n_buckets
            conf_bucket_i = torch.tensor((upper_bound + lower_bound) / 2, device=self.bucket_correct.device)
            result += (self.bucket_totals[i] / total_obs) * torch.abs(acc_bucket_i - conf_bucket_i)
        return result
class InContextLearningMCExpectedCalibrationError(InContextLearningExpectedCalibrationError):
    r"""Expected Calibration Error (ECE) for In-context learning (ICL) multiple choice (MC) tasks.

    (source: https://arxiv.org/abs/2012.00955).

    Every choice is scored by the mean softmax probability of its continuation tokens. Within a
    question, the confidence is the highest choice score divided by the sum of all choice scores, and
    the prediction is the choice with that highest score.

    See `InContextLearningExpectedCalibrationError` for more info.
    """

    # Make torchmetrics call update only once
    full_state_update = False

    def update(
        self,
        batch: dict,
        output_logits: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        outputs: Optional[torch.Tensor] = None,
    ):
        """Bucket every question of the batch by confidence and record whether it was answered correctly.

        Reads ``continuation_indices``, ``choice_groupings`` and ``gold_indices`` from ``batch``.
        ``output_logits`` is the deprecated alias of ``outputs``.
        """
        batch, outputs, labels = InContextLearningMetric.rename_args(
            batch=batch,
            output_logits=output_logits,
            labels=labels,
            outputs=outputs,
        )

        # Turn logits into probabilities over dim 2
        outputs = torch.softmax(outputs, dim=2)

        choice_scores = []
        for row, cont_idx in enumerate(batch['continuation_indices']):
            shifted_idx = cont_idx - 1
            cont_tok_probs = outputs[row].index_select(dim=0, index=shifted_idx)
            cont_tok_targ = labels[row].index_select(dim=0, index=shifted_idx)
            # The diagonal pairs position i with the probability of target token i
            target_probs = cont_tok_probs.index_select(dim=1, index=cont_tok_targ).diagonal()
            choice_scores.append(target_probs.mean())

        for (start, end), gold_idx in zip(batch['choice_groupings'], batch['gold_indices']):
            group_scores = choice_scores[start:end]
            idx_max = group_scores.index(max(group_scores))

            group_tensor = torch.tensor(group_scores)
            confidence = group_tensor.max() / group_tensor.sum()
            assert confidence >= 0.0 and confidence <= 1.0

            bucket_idx = int(confidence * self.n_buckets)
            # A confidence of exactly 1.0 belongs to the last bucket
            if bucket_idx == self.n_buckets:
                bucket_idx -= 1

            if idx_max == gold_idx:
                self.bucket_correct[bucket_idx] += 1  # pyright: ignore [reportGeneralTypeIssues]

            self.bucket_totals[bucket_idx] += 1  # pyright: ignore [reportGeneralTypeIssues]
class InContextLearningLMExpectedCalibrationError(InContextLearningExpectedCalibrationError):
    r"""Expected Calibration Error (ECE) for In-context learning (ICL) language modeling (LM) tasks.

    (cite: https://arxiv.org/pdf/1706.04599.pdf).

    The confidence of a sample is the smallest top-1 softmax probability found across its continuation
    positions. The sample is correct when the argmax prediction equals the label at every one of them.

    See `InContextLearningExpectedCalibrationError` for more info.
    """

    # Make torchmetrics call update only once
    full_state_update = False

    def update(
        self,
        batch: dict,
        output_logits: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        outputs: Optional[torch.Tensor] = None,
    ):
        """Bucket every sample of the batch by confidence and record whether it was predicted correctly.

        Reads ``batch['continuation_indices']``. ``output_logits`` is the deprecated alias of ``outputs``.
        """
        batch, outputs, labels = InContextLearningMetric.rename_args(
            batch=batch,
            output_logits=output_logits,
            labels=labels,
            outputs=outputs,
        )

        # Turn logits into probabilities over dim 2
        outputs = torch.softmax(outputs, dim=2)

        for row, cont_idx in enumerate(batch['continuation_indices']):
            shifted_idx = cont_idx - 1
            cont_tok_probs = outputs[row].index_select(dim=0, index=shifted_idx)
            cont_tok_pred = cont_tok_probs.argmax(dim=-1)
            top_probs = cont_tok_probs.max(dim=-1).values
            confidence = top_probs.min()
            cont_tok_targ = labels[row].index_select(dim=0, index=shifted_idx)
            assert confidence >= 0.0 and confidence <= 1.0

            bucket_idx = int(confidence * self.n_buckets)
            # A confidence of exactly 1.0 belongs to the last bucket
            if bucket_idx == self.n_buckets:
                bucket_idx -= 1

            if (cont_tok_pred == cont_tok_targ).all():
                self.bucket_correct[bucket_idx] += 1  # pyright: ignore [reportGeneralTypeIssues]

            self.bucket_totals[bucket_idx] += 1  # pyright: ignore [reportGeneralTypeIssues]
class InContextLearningCodeEvalAccuracy(InContextLearningMetric):
    r"""Computes accuracy for In-context learning (ICL) code evaluation tasks.

    ICL code eval tasks consist of some number of example code eval tasks (referred to as the 'context'),
    followed by a test task where the model must complete the code, where we term the code completion a
    'continuation'.

    In each case, the model constructs a given number of continuations (termed pass@K for K continuations),
    and each continuation is run against a set of test cases. The model is considered correct if at least
    one of the proposed continuations passes all the test cases.

    The execution backend is selected with the ``CODE_EVAL_DEVICE`` environment variable (``LOCAL``,
    ``LAMBDA`` or ``MOSAICML``, case-insensitive); see :meth:`get_client`.

    Adds metric state variables (created lazily on the first ``update`` call):
        correct (float): Per-sample count of generations that passed all the test cases.
        total (float): Per-sample count of generations that were evaluated.

    Args:
        dist_sync_on_step (bool, optional): Synchronize metric state across processes at
            each forward() before returning the value at the step. Default: ``False``.
    """

    # Make torchmetrics call update only once
    full_state_update = False

    def __init__(self, dist_sync_on_step: bool = False):
        # state from multiple processes
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self._initialized = False
        self.eval_device = os.environ.get('CODE_EVAL_DEVICE', None)
        if self.eval_device is not None:
            self.eval_device = self.eval_device.upper()
        # Empty template; update() deep-copies it so each call returns fresh lists
        self.metric_result_dict = {'context': [], 'output': [], 'result': [], 'sample_id': []}

    def get_client(self) -> EvalClient:
        """Returns a client for the appropriate remote platform.

        Raises:
            ValueError: If ``CODE_EVAL_DEVICE`` is unset or is not one of ``LOCAL``, ``LAMBDA``, ``MOSAICML``.
        """
        if self.eval_device == 'LOCAL':
            warnings.warn(
                'Running code eval locally may be insecure. Please set environment variable CODE_EVAL_DEVICE '
                'to LAMBDA to run on remote. To use Lambdas, spin up your instance that checks code, set the URL as '
                'CODE_EVAL_URL and the API key as CODE_EVAL_APIKEY.',
            )
            log.debug('Running code eval locally.')
            return LocalEvalClient()
        if self.eval_device == 'LAMBDA':
            return LambdaEvalClient()
        if self.eval_device == 'MOSAICML':
            return MosaicMLLambdaEvalClient()
        if self.eval_device is None:
            # Adjacent literals only: a comma here would turn the message into a tuple of arguments
            raise ValueError(
                'Attempting to use InContextLearningCodeEvalAccuracy but environment '
                'variable `CODE_EVAL_DEVICE` is not set. Please set `CODE_EVAL_DEVICE` '
                'to one of `LOCAL` (for unsafe local eval), `LAMBDA` (for AWS lambda '
                'evaluation), or `MOSAICML` (for lambda eval through MAPI).',
            )
        raise ValueError(
            'Environment variable `CODE_EVAL_DEVICE` must be one of `LOCAL`, '
            f'`LAMBDA`, or `MOSAICML` but got {self.eval_device}.',
        )

    def estimator(self, n: int, c: int, k: int) -> float:
        """Computes the pass@k metric.

        Given the number of generated samples, n, the number of correct samples, c, and the k of interest,
        this function calculates pass@k as 1 - comb(n - c, k) / comb(n, k) as per the definition of
        pass@k in the HumanEval paper (https://arxiv.org/abs/2107.03374) and its associated implementation:
        https://github.com/openai/human-eval.
        """
        # Fewer than k incorrect samples: every draw of k contains a correct one
        if n - c < k:
            return 1.0
        return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

    def _initialize_state(self, batch: Dict[str, Any]):
        """Create the per-sample accumulators from the first batch.

        Reads ``input_ids`` (for its device), ``dataset_size``, ``pass_at_k`` and
        ``generations_per_sample`` from ``batch``.
        """
        device = batch['input_ids'].device
        self.dataset_size = batch['dataset_size']
        self.pass_at_k = batch['pass_at_k']
        self.num_generations = batch['generations_per_sample']

        # We need to defer the accumulator initialization because it depends on dataset size
        self.add_state('correct', default=torch.zeros(self.dataset_size, device=device), dist_reduce_fx='sum')
        self.add_state('total', default=torch.zeros(self.dataset_size, device=device), dist_reduce_fx='sum')
        dist.barrier()
        self._initialized = True

    def update(self, batch: Dict[str, Any], outputs: List[str], labels: List[str]):
        """Updates the pass@k accuracy of code generation.

        Given a batch of prompts, test cases, and code generations, evaluates the code generations
        against the test cases and augments the pass@k accuracy of the batch to the values so far.

        Args:
            batch (Dict[str, Any]): A batch of data produced by the InContextLearningCodeEvalDataset, with
                the prompt, test cases, and entry points. The keys read here are ``sample_id``, ``prompts``,
                ``test_inputs``, ``test_outputs``, ``entry_points`` and ``languages``; on the first call
                ``input_ids``, ``dataset_size``, ``pass_at_k`` and ``generations_per_sample`` are read too.
            outputs (List[str]): A list of code generations in the format of HF generate with beam search,
                which is the a list of strings in groups of beam_size e.g. for beam size 2 and batch size 2,
                the list will be of the format [prompt 1 gen 1, prompt 1 gen 2, prompt 2 gen 1, prompt 2 gen 2]
            labels (List[str]): A list of the correct code generations, for compatibility with existing HF
                generate functionalities. This is not used.

        Returns:
            dict: Per-generation lists under ``context``, ``output``, ``result`` and ``sample_id``.
        """
        if not self._initialized:
            self._initialize_state(batch)

        del labels  # never used
        client = self.get_client()

        metric_result_dict = copy.deepcopy(self.metric_result_dict)
        try:
            for sample_id, code_gen, sample_prompt, test_inputs, test_outputs, entry_point, language in zip(
                batch['sample_id'],
                outputs,
                batch['prompts'],
                batch['test_inputs'],
                batch['test_outputs'],
                batch['entry_points'],
                batch['languages'],
            ):
                idx = sample_id
                self.total[idx] += 1.0
                metric_result_dict['sample_id'].append(sample_id)

                code_gen = re.split(r'\n[A-Za-z0-9#`]', code_gen)[0]  # remove everything after function ends
                final_code = sample_prompt + code_gen  # combine prompt with the code generation
                metric_result_dict['context'].append(sample_prompt)
                metric_result_dict['output'].append(code_gen)

                test_results = []
                for test_input, test_output in zip(test_inputs, test_outputs):
                    payload = {
                        'code': final_code,
                        'input': test_input,
                        'output': test_output,
                        'entry_point': entry_point,
                        'language': language,
                    }
                    result = client.invoke([[[payload]]])[0][0][0]
                    test_results.append(result)

                # A generation only counts when every test case passed
                if all(test_results):
                    self.correct[idx] += 1.0
                    metric_result_dict['result'].append(1)
                else:
                    metric_result_dict['result'].append(0)
        finally:
            # Release the client even if an evaluation call raised
            client.close()
        return metric_result_dict

    def compute(self):
        """Return pass@k over the samples that received all their generations.

        Returns:
            A single tensor when only one ``k`` was requested (backwards compatibility), otherwise a
            dict mapping ``'pass@{k}'`` to a tensor.

        Raises:
            ValueError: If any sample has more correct than total generations.
        """
        assert isinstance(self.correct, Tensor)
        assert isinstance(self.total, Tensor)
        complete = self.total == self.num_generations  # so that eval subset batches can be used

        if complete.sum() < (self.total != 0).sum():
            warnings.warn(
                'Some samples in the dataset have less than the expected number of generations. '
                'This is expected if you are using a subset of the dataset for evaluation.',
            )

        if (self.correct > self.total).any().item():
            raise ValueError(
                'Internal error some samples have more correct than total generations. This should not happen.',
            )

        results = {}
        n = self.num_generations

        for k in self.pass_at_k:
            estimators = [self.estimator(n, int(c.item()), k) for c in self.correct[complete]]
            pass_at_k = sum(estimators) / complete.sum().item()
            results[f'pass@{k}'] = torch.tensor(pass_at_k)

        if len(results) == 1:  # backwards compatibility
            return list(results.values())[0]

        return results