# Copyright 2021 MosaicML. All Rights Reserved.
"""`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ general and classification interfaces for
:class:`.BERTModel`."""

from dataclasses import dataclass
from typing import TYPE_CHECKING

import yahp as hp

from composer.models.transformer_hparams import TransformerHparams
from composer.utils import MissingConditionalImportError

if TYPE_CHECKING:
    from composer.models.bert import BERTModel

__all__ = ["BERTForClassificationHparams", "BERTHparams"]


@dataclass
class BERTForClassificationHparams(TransformerHparams):
"""`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ classification interface for
:class:`.BERTModel`.
Args:
pretrained_model_name (str): Pretrained model name to pull from Hugging Face Model Hub.
model_config (Dict[str, JSON]): A dictionary providing a HuggingFace model configuration.
tokenizer_name (Optional[str]): The tokenizer used for this model,
necessary to assert required model inputs.
use_pretrained (bool, optional): Whether to initialize the model with the pretrained weights.
gradient_checkpointing (bool, optional): Use gradient checkpointing. Default: ``False``.
num_labels (int, optional): The number of classes in the segmentation task. Default: ``2``.
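
    Example:
        A minimal usage sketch. The checkpoint name ``bert-base-uncased``, the import path, and
        reliance on the remaining fields' defaults are illustrative assumptions, not prescribed
        by this module:

        .. code-block:: python

            from composer.models.bert import BERTForClassificationHparams

            hparams = BERTForClassificationHparams(
                pretrained_model_name="bert-base-uncased",
                tokenizer_name="bert-base-uncased",
                use_pretrained=True,
                num_labels=2,
            )
            model = hparams.initialize_object()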
"""
num_labels: int = hp.optional(doc="The number of possible labels for the task.", default=2)

    def validate(self):
        if self.num_labels < 1:
            raise ValueError("The number of target labels must be at least one.")

    def initialize_object(self) -> "BERTModel":
        try:
            import transformers
        except ImportError as e:
            raise MissingConditionalImportError(extra_deps_group="nlp", conda_package="transformers") from e
        from composer.models.bert.model import BERTModel

        self.validate()

        model_hparams = {"num_labels": self.num_labels}
        if self.model_config:
            config = transformers.BertConfig.from_dict(self.model_config, **model_hparams)
        elif self.pretrained_model_name is not None:
            config = transformers.BertConfig.from_pretrained(self.pretrained_model_name, **model_hparams)
        else:
            raise ValueError("One of pretrained_model_name or model_config is required.")
        config.num_labels = self.num_labels

        # set up the tokenizer in the hparams interface
        if self.tokenizer_name is not None:
            tokenizer = transformers.BertTokenizer.from_pretrained(self.tokenizer_name)
        else:
            tokenizer = None
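
        # either load pretrained weights from the Hugging Face Hub or build a
        # randomly initialized model from the config assembled above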
        if self.use_pretrained:
            # TODO (Moin): handle the warnings on not using the seq_relationship head
            model = transformers.AutoModelForSequenceClassification.from_pretrained(self.pretrained_model_name,
                                                                                    **model_hparams)
        else:
            # an invariant to ensure that we don't lose keys when creating the HF config
            for k, v in model_hparams.items():
                assert getattr(config, k) == v
            model = transformers.AutoModelForSequenceClassification.from_config(config)  # type: ignore (thirdparty)

        return BERTModel(
            module=model,
            config=config,  # type: ignore (thirdparty)
            tokenizer=tokenizer,
        )


@dataclass
class BERTHparams(TransformerHparams):
"""`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ interface for :class:`.BERTModel`.
Args:
pretrained_model_name (str): "Pretrained model name to pull from Huggingface Model Hub."
model_config (Dict[str, JSON]): A dictionary providing a HuggingFace model configuration.
tokenizer_name (str): The tokenizer used for this model,
necessary to assert required model inputs.
use_pretrained (bool, optional): Whether to initialize the model with the pretrained weights.
gradient_checkpointing (bool, optional): Use gradient checkpointing. default: False.
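
    Example:
        A minimal sketch of building a masked-LM BERT from pretrained weights. The checkpoint name
        ``bert-base-uncased``, the import path, and reliance on the remaining fields' defaults are
        illustrative assumptions, not prescribed by this module:

        .. code-block:: python

            from composer.models.bert import BERTHparams

            hparams = BERTHparams(
                pretrained_model_name="bert-base-uncased",
                tokenizer_name="bert-base-uncased",
                use_pretrained=True,
            )
            model = hparams.initialize_object()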
"""

    def initialize_object(self) -> "BERTModel":
        try:
            import transformers
        except ImportError as e:
            raise MissingConditionalImportError(extra_deps_group="nlp", conda_package="transformers") from e
        from composer.models.bert.model import BERTModel

        self.validate()

        if self.model_config:
            config = transformers.BertConfig.from_dict(self.model_config)
        elif self.pretrained_model_name is not None:
            config = transformers.BertConfig.from_pretrained(self.pretrained_model_name)
        else:
            raise ValueError("One of pretrained_model_name or model_config is required.")

        # set the number of labels to the vocab size; it is used for measuring MLM accuracy
        config.num_labels = config.vocab_size

        # set up the tokenizer in the hparams interface
        if self.tokenizer_name is not None:
            tokenizer = transformers.BertTokenizer.from_pretrained(self.tokenizer_name)
        else:
            tokenizer = None
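
        # either load pretrained weights from the Hugging Face Hub or build a
        # randomly initialized masked-LM model from the config assembled above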
        if self.use_pretrained:
            # TODO (Moin): handle the warnings on not using the seq_relationship head
            model = transformers.AutoModelForMaskedLM.from_pretrained(self.pretrained_model_name)
        else:
            model = transformers.AutoModelForMaskedLM.from_config(config)  # type: ignore (thirdparty)

        return BERTModel(
            module=model,
            config=config,  # type: ignore (thirdparty)
            tokenizer=tokenizer,
        )