create_bert_classification
composer.models.create_bert_classification(num_labels=2, use_pretrained=False, pretrained_model_name=None, model_config=None, tokenizer_name=None, gradient_checkpointing=False)
BERT classification model based on 🤗 Transformers.

For more information, see Transformers.

Args:
- num_labels (int, optional): The number of classes in the classification task. Default: 2.
- gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Default: False.
- use_pretrained (bool, optional): Whether to initialize the model with the pretrained weights. Default: False.
- model_config (dict): The settings used to create a Hugging Face BertConfig. BertConfig is used to specify the architecture of a Hugging Face model.
- tokenizer_name (str, optional): Tokenizer name used to preprocess the dataset and validate the model's inputs.

Example BertConfig (bert-base-uncased with three labels):

```json
{
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.16.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}
```
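In the config above, the id2label and label2id entries hold one entry per class, named in the Hugging Face default style `LABEL_<i>`. JSON object keys are always strings, so id2label is keyed by "0", "1", and so on. If you write a model_config by hand for a different number of classes, the sketch below builds matching maps. `make_label_maps` is a hypothetical helper, not part of Composer or Transformers, and it assumes you want the default `LABEL_<i>` names rather than task-specific ones.

```python
import json


def make_label_maps(num_labels: int) -> tuple[dict[str, str], dict[str, int]]:
    """Hypothetical helper: build id2label/label2id in the LABEL_<i> style shown above."""
    if num_labels < 1:
        raise ValueError("num_labels must be at least 1")
    # JSON object keys must be strings, so class indices are stored as "0", "1", ...
    id2label = {str(i): f"LABEL_{i}" for i in range(num_labels)}
    # Inverse map: label name -> integer class index (dicts keep insertion order).
    label2id = {name: i for i, name in enumerate(id2label.values())}
    return id2label, label2id


if __name__ == "__main__":
    id2label, label2id = make_label_maps(3)
    # Print a config fragment that matches the id2label/label2id entries above.
    print(json.dumps({"id2label": id2label, "label2id": label2id}, indent=2))
```

Keeping both maps derived from one loop guarantees they stay inverse to each other, which is easy to break when editing a large config by hand.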
To create a BERT model for classification:

```python
from composer.models import create_bert_classification

# For a task with three classes.
model = create_bert_classification(num_labels=3)
```

Note:
This function can also be used to construct a BERT model for regression by setting num_labels == 1. This has two noteworthy effects. First, the training loss switches to MSELoss. Second, the returned ComposerModel's train/validation metrics are MeanSquaredError and SpearmanCorrCoef.

For the classification case (num_labels > 1), the training loss is CrossEntropyLoss, and the train/validation metrics are MulticlassAccuracy and MatthewsCorrCoef, plus BinaryF1Score when num_labels == 2.
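The rule in the note above depends only on num_labels, so it can be written as a small branch. The sketch below mirrors the documented behavior using plain strings for the loss and metric names. `describe_task_setup` and `TaskSetup` are hypothetical names for illustration; this is not Composer's implementation and it does not construct any real loss or metric objects.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TaskSetup:
    """Names of the training loss and train/validation metrics for a given num_labels."""
    loss: str
    metrics: tuple[str, ...]


def describe_task_setup(num_labels: int) -> TaskSetup:
    """Hypothetical helper mirroring the documented rule; it does not call Composer."""
    if num_labels < 1:
        raise ValueError("num_labels must be at least 1")
    if num_labels == 1:
        # Regression: a single output trained with mean squared error.
        return TaskSetup("MSELoss", ("MeanSquaredError", "SpearmanCorrCoef"))
    # Classification: cross-entropy loss with accuracy and Matthews correlation.
    metrics = ["MulticlassAccuracy", "MatthewsCorrCoef"]
    if num_labels == 2:
        # The binary case additionally reports F1.
        metrics.append("BinaryF1Score")
    return TaskSetup("CrossEntropyLoss", tuple(metrics))


if __name__ == "__main__":
    for n in (1, 2, 5):
        print(n, describe_task_setup(n))
```

For example, num_labels=2 yields CrossEntropyLoss with three metrics, while num_labels=5 drops BinaryF1Score.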