Source code for composer.models.bert.model

# Copyright 2021 MosaicML. All Rights Reserved.

"""Implements a BERT wrapper around a :class:`.ComposerTransformer`."""

from __future__ import annotations

from typing import TYPE_CHECKING, Mapping, Optional, Sequence, Union

import torch
from torchmetrics import MeanSquaredError, Metric, MetricCollection
from torchmetrics.classification.accuracy import Accuracy
from torchmetrics.classification.matthews_corrcoef import MatthewsCorrCoef
from torchmetrics.regression.spearman import SpearmanCorrCoef

from composer.metrics.nlp import BinaryF1Score, LanguageCrossEntropy, MaskedAccuracy
from composer.models.transformer_shared import ComposerTransformer

if TYPE_CHECKING:
    import transformers

    from composer.core.types import Batch, BatchDict, BatchPair

__all__ = ["BERTModel"]


class BERTModel(ComposerTransformer):
    """BERT model based on |:hugging_face:| Transformers.

    For more information, see `Transformers <https://huggingface.co/transformers/>`_.

    Args:
        module (transformers.BertModel): An instance of BertModel that contains the forward pass function.
        config (transformers.BertConfig): The BertConfig object that stores information about the model hyperparameters.
        tokenizer (transformers.BertTokenizer): An instance of BertTokenizer. Necessary to process model inputs.

    To create a BERT model for Language Model pretraining:

    .. testcode::

        from composer.models import BERTModel
        import transformers

        config = transformers.BertConfig()
        hf_model = transformers.BertLMHeadModel(config=config)
        tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
        model = BERTModel(module=hf_model, config=config, tokenizer=tokenizer)
    """

    def __init__(self,
                 module: transformers.BertModel,
                 config: transformers.BertConfig,
                 tokenizer: Optional[transformers.BertTokenizer] = None) -> None:

        if tokenizer is None:
            model_inputs = {"input_ids", "attention_mask", "token_type_ids"}
        else:
            model_inputs = set(tokenizer.model_input_names)

        super().__init__(
            module=module,  # type: ignore (thirdparty)
            config=config,
            model_inputs=model_inputs)

        # Remove the label from the expected inputs, since metric calculation is
        # handled with TorchMetrics instead of HuggingFace.
        self.model_inputs.remove("labels")

        # When using Evaluators, the validation metrics represent all possible
        # validation metrics that can be used with the BERT model.
        # The Evaluator class checks that its metrics are among the model's validation metrics.
        ignore_index = -100
        self.val_metrics = [
            Accuracy(),
            MeanSquaredError(),
            SpearmanCorrCoef(),
            BinaryF1Score(),
            MatthewsCorrCoef(num_classes=config.num_labels),
            LanguageCrossEntropy(ignore_index=ignore_index, vocab_size=config.num_labels),
            MaskedAccuracy(ignore_index=ignore_index),
        ]
        self.train_metrics = []

    def loss(self, outputs: Mapping, batch: Batch) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
        if outputs.get('loss', None) is not None:
            return outputs['loss']
        else:
            raise NotImplementedError('Calculating loss directly not supported yet.')
    def validate(self, batch: BatchDict) -> BatchPair:
        """Runs the validation step.

        Args:
            batch (BatchDict): a dictionary of Dict[str, Tensor] of inputs that the model expects,
                as found in :meth:`.ComposerTransformer.get_model_inputs`.

        Returns:
            tuple (Tensor, Tensor): the output from the forward pass and the correct labels.
                This is fed directly into :meth:`.ComposerModel.metrics`.
        """
        assert self.training is False, "For validation, model must be in eval mode"

        # temporary hack until eval on multiple datasets is finished
        labels = batch.pop('labels')
        output = self.forward(batch)
        output = output['logits']

        # if we are in the single-class case, then remove the classes dimension
        if output.shape[1] == 1:
            output = output.squeeze(dim=1)

        return output, labels
    def metrics(self, train: bool = False) -> Union[Metric, MetricCollection]:
        return MetricCollection(self.train_metrics) if train else MetricCollection(self.val_metrics)
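
A minimal usage sketch (not part of the source above), assuming a hypothetical two-sentence classification batch and a ``BertForSequenceClassification`` head rather than the LM head shown in the docstring; it illustrates how ``validate`` and ``metrics`` fit together at evaluation time:

    import torch
    import transformers

    from composer.models import BERTModel

    config = transformers.BertConfig(num_labels=2)
    hf_model = transformers.BertForSequenceClassification(config=config)
    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
    model = BERTModel(module=hf_model, config=config, tokenizer=tokenizer)
    model.eval()  # validate() asserts the model is in eval mode

    # Hypothetical batch; a real one would come from a fine-tuning dataloader
    # whose keys match tokenizer.model_input_names plus "labels".
    batch = dict(tokenizer(["a first sentence", "a second sentence"],
                           return_tensors="pt", padding=True))
    batch["labels"] = torch.tensor([0, 1])

    # validate() pops the labels, runs the forward pass, and returns (logits, labels).
    outputs, labels = model.validate(batch)

    # metrics(train=False) returns the full validation MetricCollection; here only
    # the Accuracy entry is updated, since not every metric fits a classification task.
    metrics = model.metrics(train=False)
    metrics["Accuracy"].update(outputs, labels)
    print(metrics["Accuracy"].compute())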