Source code for composer.models.gpt2.gpt2_hparams

# Copyright 2021 MosaicML. All Rights Reserved.

"""`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ interface for :class:`.GPT2Model`."""

import dataclasses
from typing import TYPE_CHECKING

from composer.models.transformer_hparams import TransformerHparams
from composer.utils.import_helpers import MissingConditionalImportError

if TYPE_CHECKING:
    from composer.models.transformer_shared import ComposerTransformer

__all__ = ["GPT2Hparams"]


@dataclasses.dataclass
class GPT2Hparams(TransformerHparams):
    """`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ interface for :class:`.GPT2Model`.

    Args:
        pretrained_model_name (str): Pretrained model name to pull from Hugging Face Model Hub.
        model_config (Dict[str, JSON]): A dictionary providing a HuggingFace model configuration.
        tokenizer_name (Optional[str]): The tokenizer used for this model,
            necessary to assert required model inputs.
        use_pretrained (bool, optional): Whether to initialize the model with the pretrained weights.
            Default: ``False``.
        gradient_checkpointing (bool, optional): Use gradient checkpointing. Default: ``False``.
    """

    def initialize_object(self) -> "ComposerTransformer":
        try:
            import transformers
        except ImportError as e:
            raise MissingConditionalImportError(extra_deps_group="nlp", conda_package="transformers") from e
        from composer.models.gpt2.model import GPT2Model

        self.validate()

        if self.model_config:
            config = transformers.GPT2Config.from_dict(self.model_config)
        elif self.pretrained_model_name is not None:
            # TODO (Moin): verify that the config is an appropriate instance of GPT2!
            config = transformers.GPT2Config.from_pretrained(self.pretrained_model_name)
        else:
            raise ValueError('One of pretrained_model_name or model_config needed.')

        # setup the tokenizer in the hparams interface
        if self.tokenizer_name is not None:
            tokenizer = transformers.GPT2Tokenizer.from_pretrained(self.tokenizer_name)
        else:
            tokenizer = None

        if self.use_pretrained:
            model = transformers.AutoModelForCausalLM.from_pretrained(self.pretrained_model_name)
        else:
            model = transformers.AutoModelForCausalLM.from_config(config)  # type: ignore (thirdparty)

        return GPT2Model(
            module=model,
            config=config,  # type: ignore (thirdparty)
            tokenizer=tokenizer,
            gradient_checkpointing=self.gradient_checkpointing,
        )
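
A minimal usage sketch (not part of the original file): constructing the hparams object directly in Python and materializing the model. The field names mirror the ``TransformerHparams`` arguments documented above; the specific values ("gpt2" as the model and tokenizer name) are illustrative assumptions, not defaults.

    # Hypothetical usage sketch; values are assumptions, not library defaults.
    hparams = GPT2Hparams(
        pretrained_model_name="gpt2",  # pulled from the Hugging Face Model Hub
        model_config={},               # empty dict -> fall back to the pretrained config
        tokenizer_name="gpt2",
        use_pretrained=True,
        gradient_checkpointing=False,
    )
    model = hparams.initialize_object()  # returns a GPT2Model wrapping the HF module

Because ``GPT2Hparams`` is a YAHP dataclass, the same fields can alternatively be populated from a YAML file or command-line arguments via YAHP rather than constructed by hand.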