This tutorial is available as a Jupyter notebook.

🤗 Fine-tuning Hugging Face Models

Want to use Hugging Face models with Composer? No problem. Here, we’ll walk through using Composer to fine-tune a pretrained Hugging Face BERT model.

Tutorial Goals and Concepts Covered

The goal of this tutorial is to demonstrate how to fine-tune a pretrained Hugging Face transformer using the Composer library!

We will focus on fine-tuning a pretrained BERT-base model on the Stanford Sentiment Treebank v2 (SST-2) dataset. After fine-tuning, the BERT model should be able to determine if a sentence has positive or negative sentiment.

Along the way, we will touch on:

  • Creating our Hugging Face BERT model, tokenizer, and data loaders

  • Wrapping the Hugging Face model as a ComposerModel for use with the Composer trainer

  • Training with Composer

  • Visualization examples

Let’s do this 🚀

Install Composer

To use Hugging Face with Composer, we’ll need to install Composer with the NLP dependencies. If you haven’t already, run:

[ ]:
%pip install 'mosaicml[nlp, tensorboard]'
# To install from source instead of the latest release, comment out the command above and uncomment the following one.
# %pip install 'mosaicml[nlp, tensorboard] @ git+https://github.com/mosaicml/composer.git'

Import Hugging Face Pretrained Model

First, we import a pretrained BERT model (specifically, BERT-base for uncased text) and its associated tokenizer from the transformers library.

Sentiment classification has two labels, so we set num_labels=2 when creating our model.

[ ]:
import transformers

# Create a BERT sequence classification model using Hugging Face transformers
model = transformers.AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-uncased')

Creating Dataloaders

Next, we will download the SST-2 dataset and tokenize it.

[ ]:
import os

import datasets

# Use all but one CPU core for parallel preprocessing (but always at least one)
num_proc = max(1, (os.cpu_count() or 1) - 1)

# Tokenize a batch of sentences, padding or truncating each one to 256 tokens
def tokenize_function(sample):
    return tokenizer(
        text=sample['sentence'],
        padding="max_length",
        max_length=256,
        truncation=True
    )

# Download SST-2 (part of the GLUE benchmark) and tokenize every split
sst2_dataset = datasets.load_dataset("glue", "sst2", num_proc=num_proc)
tokenized_sst2_dataset = sst2_dataset.map(tokenize_function,
                                          batched=True,
                                          num_proc=num_proc,
                                          batch_size=100,
                                          remove_columns=['idx', 'sentence'])

# Select the existing train and validation splits
train_dataset = tokenized_sst2_dataset["train"]
eval_dataset = tokenized_sst2_dataset["validation"]
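To make padding="max_length" and truncation=True concrete, here is a minimal pure-Python sketch of what happens to each tokenized sentence. The helper pad_or_truncate is hypothetical and only approximates the tokenizer: it assumes BERT’s pad token id of 0 and ignores the special [CLS]/[SEP] tokens that the real tokenizer adds and preserves when truncating.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Hypothetical helper: force a list of token ids to exactly max_length.

    Returns (input_ids, attention_mask), where the mask is 1 for real tokens
    and 0 for padding, mirroring the fields the Hugging Face tokenizer emits.
    """
    # truncation=True: drop any tokens beyond max_length
    ids = list(token_ids[:max_length])
    # The attention mask marks the real (non-padding) positions
    mask = [1] * len(ids)
    # padding="max_length": right-pad with pad_id up to max_length
    num_pad = max_length - len(ids)
    ids += [pad_id] * num_pad
    mask += [0] * num_pad
    return ids, mask


# Toy token ids with max_length=6 instead of the tutorial's 256
short_ids, short_mask = pad_or_truncate([101, 2023, 3185, 102], max_length=6)
long_ids, long_mask = pad_or_truncate(list(range(1, 10)), max_length=6)
print(short_ids, short_mask)  # [101, 2023, 3185, 102, 0, 0] [1, 1, 1, 1, 0, 0]
print(long_ids, long_mask)    # [1, 2, 3, 4, 5, 6] [1, 1, 1, 1, 1, 1]
```

Because every example ends up exactly 256 tokens long, the collator can stack examples into fixed-size tensors without further padding. The trade-off is wasted computation on short SST-2 sentences; dynamic per-batch padding (for example, transformers.DataCollatorWithPadding) is a common alternative.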

Next, we create a PyTorch DataLoader for the tokenized training and validation datasets.

[ ]:
from torch.utils.data import DataLoader

# default_data_collator stacks the fixed-length tokenized examples into batched tensors
data_collator = transformers.data.data_collator.default_data_collator
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=False, drop_last=False, collate_fn=data_collator)
eval_dataloader = DataLoader(eval_dataset, batch_size=16, shuffle=False, drop_last=False, collate_fn=data_collator)
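Conceptually, the collator turns a list of 16 per-example dictionaries into one dictionary of batched tensors. The sketch below is a simplified, hypothetical stand-in (stack_features) for default_data_collator, not its actual implementation. It assumes every feature is an integer list of the same length, which holds here because of padding="max_length", and, like the Hugging Face collator, it renames label to labels so the model can compute its loss.

```python
import torch


def stack_features(features):
    """Hypothetical, simplified collator: list of dicts -> dict of tensors."""
    batch = {}
    # Hugging Face models expect the targets under the key "labels"
    batch["labels"] = torch.tensor([f["label"] for f in features], dtype=torch.long)
    for key in features[0]:
        if key == "label":
            continue
        # Every example has the same length, so the lists stack into a 2-D tensor
        batch[key] = torch.tensor([f[key] for f in features], dtype=torch.long)
    return batch


# Two toy examples with sequences of length 4 (the tutorial uses 256)
features = [
    {"input_ids": [101, 7, 102, 0], "attention_mask": [1, 1, 1, 0], "label": 1},
    {"input_ids": [101, 8, 9, 102], "attention_mask": [1, 1, 1, 1], "label": 0},
]
batch = stack_features(features)
print({k: tuple(v.shape) for k, v in batch.items()})
# {'labels': (2,), 'input_ids': (2, 4), 'attention_mask': (2, 4)}
```

With batch_size=16 and max_length=256, each input_ids tensor in the tutorial’s loaders has shape (16, 256), except possibly the final batch, since drop_last=False keeps a partial batch.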

Convert model to ComposerModel

Composer provides HuggingFaceModel, a convenient wrapper that turns a Hugging Face model (such as the one we created above) into a ComposerModel. Its parameters are:

  • model: The Hugging Face model to wrap.

  • tokenizer: The Hugging Face tokenizer used to create the input data.

  • metrics: A list of torchmetrics to apply to the output of eval_forward (a ComposerModel method).

  • use_logits: A boolean which, if True, flags that the model’s output logits should be used to calculate validation metrics.

See the API Reference for additional details.

[ ]:
from torchmetrics.classification import MulticlassAccuracy
from composer.models.huggingface import HuggingFaceModel
from composer.metrics import CrossEntropy

metrics = [CrossEntropy(), MulticlassAccuracy(num_classes=2, average='micro')]
# Package as a trainer-friendly Composer model
composer_model = HuggingFaceModel(model, tokenizer=tokenizer, metrics=metrics, use_logits=True)
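The two metrics summarize different things: CrossEntropy averages the negative log-probability that the model assigns to the correct class, and MulticlassAccuracy with average='micro' is the fraction of examples whose highest-scoring class matches the label. The pure-Python sketch below reproduces both quantities from raw logits for a toy batch. The helper names are hypothetical, and the logits are made-up numbers rather than tutorial output.

```python
import math


def cross_entropy(logits, labels):
    """Mean negative log-softmax probability of the true class."""
    total = 0.0
    for row, label in zip(logits, labels):
        # log-sum-exp with max subtraction for numerical stability
        m = max(row)
        log_norm = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_norm - row[label]
    return total / len(labels)


def accuracy(logits, labels):
    """Fraction of rows whose argmax equals the label (micro-averaged)."""
    correct = sum(
        # max() returns the first maximal index, so ties go to the lower class
        max(range(len(row)), key=row.__getitem__) == label
        for row, label in zip(logits, labels)
    )
    return correct / len(labels)


# Toy logits for 4 sentences; column 0 = negative, column 1 = positive
logits = [[2.0, 0.0], [0.5, 1.5], [0.0, 0.0], [3.0, -1.0]]
labels = [0, 1, 1, 1]
print(f"cross-entropy: {cross_entropy(logits, labels):.4f}")
print(f"accuracy: {accuracy(logits, labels):.2f}")
```

Note how the last example, a confident wrong prediction, dominates the cross-entropy even though it counts as just one miss for accuracy; this is why tracking both metrics is informative.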

Optimizers and Learning Rate Schedulers

The last setup step is to create an optimizer and a learning rate scheduler. We will use PyTorch’s AdamW optimizer and linear learning rate scheduler, since these are typically used to fine-tune BERT on tasks such as SST-2.

[ ]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR

optimizer = AdamW(
    params=composer_model.parameters(),
    lr=3e-5, betas=(0.9, 0.98),
    eps=1e-6, weight_decay=3e-6
)
linear_lr_decay = LinearLR(
    optimizer, start_factor=1.0,
    end_factor=0, total_iters=150
)
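With start_factor=1.0, end_factor=0, and total_iters=150, LinearLR scales the base learning rate of 3e-5 linearly down to zero over 150 scheduler steps and then holds it at zero. The pure-Python sketch below evaluates that linear formula (the helper linear_lr is hypothetical, not a PyTorch function). It assumes Composer’s default of stepping PyTorch schedulers after every batch, so a step here corresponds to one training batch.

```python
def linear_lr(step, base_lr=3e-5, start_factor=1.0, end_factor=0.0, total_iters=150):
    """Learning rate of a linear schedule after `step` scheduler steps."""
    # Progress is clamped at total_iters, so the rate stays at end_factor afterwards
    progress = min(step, total_iters) / total_iters
    factor = start_factor + (end_factor - start_factor) * progress
    return base_lr * factor


for step in (0, 50, 75, 149, 150, 200):
    print(f"step {step:>3}: lr = {linear_lr(step):.2e}")
```

If training ran longer than 150 batches, the learning rate would remain at zero and further steps would leave the weights unchanged, so total_iters should be kept in sync with the training length.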

Composer Trainer

We will now specify a Composer Trainer object and run our training! Trainer has many arguments that are described in our documentation, so we’ll discuss only the less-obvious arguments used below:

  • max_duration - a string specifying how long to train. This can be in terms of batches (e.g., '10ba' is 10 batches) or epochs (e.g., '1ep' is 1 epoch), among other options.

  • schedulers - a (list of) PyTorch or Composer learning rate scheduler(s) that will be composed together.

  • device - specifies whether training runs on the CPU or GPU, using 'cpu' or 'gpu', respectively. You can omit this argument to train on GPUs automatically when they’re available and fall back to the CPU otherwise.

  • train_subset_num_batches - specifies the number of training batches to use for each epoch. This is not a necessary argument but is useful for quickly testing code.

  • precision - whether to train in full precision ('fp32') or mixed precision ('amp_fp16' or 'amp_bf16' in recent Composer releases; older releases used 'amp'). Mixed precision can speed up training by up to ~2x on recent NVIDIA GPUs.

  • seed - sets the random seed for the training run, so the results are reproducible!

[ ]:
import torch
from composer import Trainer

# Create Trainer Object
trainer = Trainer(
    model=composer_model, # This is the model from the HuggingFaceModel wrapper class.
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="1ep",
    optimizers=optimizer,
    schedulers=[linear_lr_decay],
    device='gpu' if torch.cuda.is_available() else 'cpu',
    train_subset_num_batches=150,
    precision='fp32',
    seed=17
)
# Start training
trainer.fit()
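A quick back-of-the-envelope calculation shows how little data this run actually uses. The sketch assumes the standard GLUE SST-2 split sizes (67,349 training and 872 validation sentences); train_subset_num_batches=150 caps the single epoch at 150 batches of 16.

```python
import math

train_size, eval_size = 67_349, 872  # GLUE SST-2 split sizes (assumed)
batch_size = 16
subset_batches = 150  # train_subset_num_batches

# drop_last=False keeps the final partial batch, hence the ceiling
full_epoch_batches = math.ceil(train_size / batch_size)
eval_batches = math.ceil(eval_size / batch_size)
samples_seen = subset_batches * batch_size

print(f"batches in a full training epoch: {full_epoch_batches}")
print(f"validation batches: {eval_batches}")
print(f"training samples seen: {samples_seen} "
      f"({samples_seen / train_size:.1%} of the training set)")
```

Because the training DataLoader uses shuffle=False, these are always the same first 2,400 sentences of the training split. Evaluation is not subset by default, so the full validation split is scored.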

Visualizing Results

To check the validation accuracy after training, we read the Trainer’s state.eval_metrics attribute:

[ ]:
trainer.state.eval_metrics

Our model reaches ~86% accuracy with only 150 batches of training! Let’s visualize a few samples from the validation set to see how our model performs.

[ ]:
eval_batch = next(iter(eval_dataloader))

# Move the batch to the GPU if one is available (the Trainer placed the model there)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_batch = {k: v.to(device) for k, v in eval_batch.items()}

# Switch to eval mode (disables dropout) and run inference without tracking gradients
composer_model.eval()
with torch.no_grad():
    predictions = composer_model(eval_batch)["logits"].argmax(dim=1)

# Visualize only 5 samples
predictions = predictions[:5]

label = ['negative', 'positive']
for i, prediction in enumerate(predictions):
    sentence = sst2_dataset["validation"][i]["sentence"]
    correct_label = label[sst2_dataset["validation"][i]["label"]]
    prediction_label = label[prediction.item()]
    print(f"Sample: {sentence}")
    print(f"Label: {correct_label}")
    print(f"Prediction: {prediction_label}")
    print()
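The raw trainer.state.eval_metrics value is a nested dictionary, keyed first by evaluator label and then by metric name, whose values are torchmetrics objects rather than plain numbers. A small helper like the hypothetical summarize_eval_metrics below flattens it into floats by calling each metric’s compute(). The example uses a stand-in metric class, and the key names ('eval', 'MulticlassAccuracy', 'CrossEntropy') are assumptions about what the default evaluator produces.

```python
def summarize_eval_metrics(eval_metrics):
    """Hypothetical helper: {label: {name: metric}} -> {"label/name": float}."""
    summary = {}
    for label, metrics in eval_metrics.items():
        for name, metric in metrics.items():
            # torchmetrics objects return a tensor from compute(); float() unwraps it
            summary[f"{label}/{name}"] = float(metric.compute())
    return summary


class _FakeMetric:
    """Stand-in for a torchmetrics Metric, used only for illustration."""

    def __init__(self, value):
        self.value = value

    def compute(self):
        return self.value


example = {"eval": {"MulticlassAccuracy": _FakeMetric(0.5), "CrossEntropy": _FakeMetric(1.25)}}
print(summarize_eval_metrics(example))
```

In the notebook, you would call summarize_eval_metrics(trainer.state.eval_metrics) to get a dictionary of plain numbers that is easy to log or compare across runs.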

Save Fine-Tuned Model

Finally, to save the fine-tuned model parameters, we call the PyTorch save method and pass it the model’s state_dict:

[ ]:
torch.save(trainer.state.model.state_dict(), 'model.pt')
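To reuse the weights later, load the saved state_dict back into a model with the same architecture. The sketch below demonstrates the round trip on a small torch.nn.Linear stand-in so that it runs without downloading BERT. For this tutorial, you would instead rebuild composer_model as above and call composer_model.load_state_dict(torch.load('model.pt')), assuming the saved keys come from the same HuggingFaceModel wrapper.

```python
import os
import tempfile

import torch

# Stand-in for the fine-tuned model: any nn.Module works the same way
trained = torch.nn.Linear(4, 2)
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(trained.state_dict(), path)

# Rebuild the same architecture, then copy the saved parameters into it
restored = torch.nn.Linear(4, 2)
restored.load_state_dict(torch.load(path))
restored.eval()  # disable training-only behavior such as dropout before inference

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.equal(trained(x), restored(x))
print(same)  # True
```

Alternatively, the wrapped Hugging Face model is available as composer_model.model, so calling save_pretrained on it (and on the tokenizer) writes a directory that transformers’ from_pretrained can load directly.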

What next?

You’ve now seen how to use the Composer Trainer to fine-tune a pretrained Hugging Face BERT model on a subset of the SST-2 dataset.

To keep learning, try looking through some of the documents linked throughout this tutorial to build a deeper intuition for what’s going on in these examples.

In addition, please continue to explore our tutorials and examples! Here are a couple of suggestions:

  • Explore domain-specific pretraining of a Hugging Face model in a second Hugging Face + Composer tutorial.

  • Explore more advanced applications of Composer like applying image segmentation to medical images.

  • Learn about callbacks and how to apply early stopping.

  • Check out the examples repo for full examples of training large language models like GPT and BERT, image segmentation models like DeepLab, and more!

Come get involved with MosaicML!

We’d love for you to get involved with the MosaicML community in any of these ways:

Star Composer on GitHub

Help make others aware of our work by starring Composer on GitHub.

Join the MosaicML Slack

Head on over to the MosaicML Slack to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!

Contribute to Composer

Is there a bug you noticed or a feature you’d like? File an issue or make a pull request!