🤗 Pretraining and Finetuning with Hugging Face Models
Want to pretrain and finetune a Hugging Face model with Composer? No problem. Here, we'll walk through using Composer to pretrain and finetune a Hugging Face model.
Recommended Background
If you have already gone through our tutorial on finetuning a pretrained Hugging Face model with Composer, many parts of this tutorial will be familiar to you, but it is not necessary to do that one first.
This tutorial assumes you are familiar with transformer models for NLP and with Hugging Face.
To better understand the Composer part, make sure you're comfortable with the material in our Getting Started tutorial.
Tutorial Goals and Concepts Covered
The goal of this tutorial is to demonstrate how to pretrain and finetune a Hugging Face transformer using the Composer library!
Inspired by this paper showing that performing unsupervised pretraining on the downstream dataset can be surprisingly effective, we will focus on pretraining and finetuning a small version of Electra on the AG News dataset!
Along the way, we will touch on:
- Creating our Hugging Face model, tokenizer, and data loaders 
- Wrapping the Hugging Face model as a ComposerModel for use with the Composer trainer
- Reloading the pretrained model with a new head for sequence classification 
- Training with Composer 
Let's do this!
Install Composer
To use Hugging Face with Composer, we'll need to install Composer with the NLP dependencies. If you haven't already, run:
%pip install 'mosaicml[nlp]'
Import Hugging Face Model
First, we import an Electra model and its associated tokenizer from the transformers library. We use Electra small in this notebook so that our model trains quickly.
import transformers
from composer.utils import reproducibility
# Create an Electra masked language modeling model using Hugging Face transformers
# Note: this is just loading the model architecture, and is using randomly initialized weights, so it is important to set
# the random seed here
reproducibility.seed_all(17)
config = transformers.AutoConfig.from_pretrained('google/electra-small-discriminator')
model = transformers.AutoModelForMaskedLM.from_config(config)
tokenizer = transformers.AutoTokenizer.from_pretrained('google/electra-small-discriminator')
Creating Dataloaders
For the purpose of this tutorial, we are going to perform unsupervised pretraining (masked language modeling) on our downstream dataset, AG News. We are only going to train for one epoch here, but note that the paper that showed good performance from pretraining on the downstream dataset trained for much longer.
import datasets
from torch.utils.data import DataLoader
# Load the AG News dataset from Hugging Face
agnews_dataset = datasets.load_dataset('ag_news')
# Split the dataset randomly into a train and eval set
split_dict = agnews_dataset['train'].train_test_split(test_size=0.2, shuffle=True, seed=17)
train_dataset = split_dict['train']
eval_dataset = split_dict['test']
text_column_name = 'text'
# Tokenize the datasets
def tokenize_function(examples):
    # Remove empty lines
    examples[text_column_name] = [
        line for line in examples[text_column_name] if len(line) > 0 and not line.isspace()
    ]
    return tokenizer(
        examples[text_column_name],
        padding='max_length',
        truncation=True,
        max_length=256,
        return_special_tokens_mask=True,
    )
tokenized_train = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=[text_column_name, 'label'],
    load_from_cache_file=False,
)
tokenized_eval = eval_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=[text_column_name, 'label'],
    load_from_cache_file=False,
)
# We use the language modeling data collator from Hugging Face which will handle preparing the inputs correctly
# for masked language modeling
collator = transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
# Create the dataloaders
train_dataloader = DataLoader(tokenized_train, batch_size=64, collate_fn=collator)
eval_dataloader = DataLoader(tokenized_eval, batch_size=64, collate_fn=collator)
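To make the collator's job concrete, here is a minimal pure-Python sketch of the masking step for a single sequence. It is an illustration only, not the Hugging Face implementation: the function name mask_tokens and the example token IDs are hypothetical, and it assumes the usual BERT-style recipe in which a selected token is replaced by the mask token 80% of the time, by a random token 10% of the time, and left unchanged otherwise.

```python
import random

IGNORE_INDEX = -100  # label value that the loss and metrics skip


def mask_tokens(input_ids, special_tokens_mask, mask_token_id, vocab_size,
                mlm_probability=0.15, rng=None):
    """Return (masked_input_ids, labels) for one tokenized sequence.

    Hypothetical helper for illustration; the real collator works on
    batched tensors.
    """
    rng = rng or random.Random()
    masked_ids = list(input_ids)
    labels = [IGNORE_INDEX] * len(input_ids)
    for i, (token_id, is_special) in enumerate(zip(input_ids, special_tokens_mask)):
        # Special tokens (e.g. [CLS], [SEP], padding) are never selected
        if is_special or rng.random() >= mlm_probability:
            continue
        labels[i] = token_id  # the model must predict the original token here
        roll = rng.random()
        if roll < 0.8:
            masked_ids[i] = mask_token_id  # 80%: replace with the mask token
        elif roll < 0.9:
            masked_ids[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token unchanged
    return masked_ids, labels


# Illustrative (assumed) token IDs: 1 in the mask marks a special token
example_ids = [101, 2023, 2003, 1037, 3231, 102, 0, 0]
example_special = [1, 0, 0, 0, 0, 1, 1, 1]
masked, labels = mask_tokens(example_ids, example_special, mask_token_id=103,
                             vocab_size=30522, rng=random.Random(17))
print(masked)
print(labels)
```

Every position that was not selected gets the label -100, so only the selected tokens contribute to the loss and metrics. The special_tokens_mask requested from the tokenizer above is what lets the collator avoid masking special and padding tokens.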
Convert model to ComposerModel
Composer uses HuggingFaceModel as a convenient interface for wrapping a Hugging Face model (such as the one we created above) in a ComposerModel. Its parameters are:
- model: The Hugging Face model to wrap.
- tokenizer: The Hugging Face tokenizer used to create the input data.
- metrics: A list of torchmetrics to apply to the output of eval_forward (a ComposerModel method).
- use_logits: A boolean which, if True, flags that the model's output logits should be used to calculate validation metrics.
See the API Reference for additional details.
from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy
from composer.models.huggingface import HuggingFaceModel
metrics = [
    LanguageCrossEntropy(ignore_index=-100),
    MaskedAccuracy(ignore_index=-100)
]
# Package as a trainer-friendly Composer model
composer_model = HuggingFaceModel(model, tokenizer=tokenizer, metrics=metrics, use_logits=True)
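The ignore_index=-100 argument matters because most positions in a masked language modeling batch carry the label -100 and should not be scored. The sketch below shows the idea in plain Python. It is a conceptual illustration, not Composer's MaskedAccuracy implementation: the helper name masked_accuracy is hypothetical, and it takes already-argmaxed predictions rather than logits to keep the example short.

```python
IGNORE_INDEX = -100


def masked_accuracy(predictions, labels, ignore_index=IGNORE_INDEX):
    """Fraction of correct predictions over positions whose label is not ignore_index."""
    scored = [(p, y) for p, y in zip(predictions, labels) if y != ignore_index]
    if not scored:
        return 0.0  # nothing to score in this sequence
    return sum(p == y for p, y in scored) / len(scored)


# Only positions 1 and 3 were masked, so only they are scored
example_labels = [-100, 2023, -100, 1037, -100]
example_predictions = [7, 2023, 9, 4000, 11]
print(masked_accuracy(example_predictions, example_labels))  # 1 of 2 scored positions is correct: 0.5
```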
Optimizers and Learning Rate Schedulers
The last setup step is to create an optimizer and a learning rate scheduler. We will use Composer's DecoupledAdamW optimizer and LinearWithWarmupScheduler.
from composer.optim import DecoupledAdamW, LinearWithWarmupScheduler
optimizer = DecoupledAdamW(composer_model.parameters(), lr=1.0e-4, betas=[0.9, 0.98], eps=1.0e-06, weight_decay=1.0e-5)
lr_scheduler = LinearWithWarmupScheduler(t_warmup='250ba', alpha_f=0.02)
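To build intuition for what this schedule does to the learning rate, here is a small standalone sketch of a linear warmup followed by a linear decay. It is an approximation written for this tutorial, not Composer's code: the function name linear_with_warmup is hypothetical, and the total of 1,500 batches is an assumption based on the 96,000-example training split above (80% of AG News's 120,000 training examples) at batch size 64.

```python
def linear_with_warmup(step, t_warmup, t_max, alpha_i=1.0, alpha_f=0.0):
    """Learning rate multiplier at a given batch (assumed schedule shape).

    Warm up linearly from 0 to alpha_i over t_warmup batches, then move
    linearly from alpha_i to alpha_f by batch t_max.
    """
    if step < t_warmup:
        return alpha_i * step / t_warmup
    frac = min((step - t_warmup) / (t_max - t_warmup), 1.0)
    return alpha_i + frac * (alpha_f - alpha_i)


base_lr = 1.0e-4       # same base learning rate as the optimizer above
total_batches = 1500   # assumption: 96,000 examples / batch size 64, one epoch
for step in (0, 125, 250, 875, 1500):
    multiplier = linear_with_warmup(step, t_warmup=250, t_max=total_batches, alpha_f=0.02)
    print(step, multiplier, base_lr * multiplier)
```

Under these assumptions the learning rate climbs to its full value of 1.0e-4 at batch 250 and ends the epoch at 2% of that value, which is what alpha_f=0.02 requests.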
Composer Trainer
We will now specify a Composer Trainer object and run our training! Trainer has many arguments that are described in our documentation, so we'll discuss only the less-obvious arguments used below:
- max_duration: a string specifying how long to train. This can be in terms of batches (e.g., '10ba' is 10 batches) or epochs (e.g., '1ep' is 1 epoch), among other options.
- save_folder: a string specifying where to save checkpoints to.
- schedulers: a (list of) PyTorch or Composer learning rate scheduler(s) that will be composed together.
- device: specifies if the training will be done on CPU or GPU by using 'cpu' or 'gpu', respectively. You can omit this to automatically train on GPUs if they're available and fall back to the CPU if not.
- train_subset_num_batches: specifies the number of training batches to use for each epoch. This is not a necessary argument but is useful for quickly testing code.
- precision: whether to do the training in full precision ('fp32') or mixed precision ('amp_fp16' or 'amp_bf16'). Mixed precision can provide a ~2x training speedup on recent NVIDIA GPUs.
- seed: sets the random seed for the training run, so the results are reproducible!
import torch
from composer import Trainer
# Create Trainer Object
trainer = Trainer(
    model=composer_model, # This is the model from the HuggingFaceModel wrapper class.
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='1ep', # train for more epochs to get better performance
    save_folder='checkpoints/pretraining/',
    optimizers=optimizer,
    schedulers=[lr_scheduler],
    device='gpu' if torch.cuda.is_available() else 'cpu',
    #train_subset_num_batches=100, # uncomment this line to only run part of training, which will be faster
    precision='fp32',
    seed=17,
)
# Start training
trainer.fit()
trainer.close()
Loading the pretrained model for finetuning
Now that we have a pretrained Hugging Face model, we will load it and finetune it on a sequence classification task. Composer provides utilities to easily reload a Hugging Face model and tokenizer from a Composer checkpoint, and to add a task-specific head to the model so that it can be finetuned for a new task.
from torchmetrics.classification import MulticlassAccuracy
from composer.metrics import CrossEntropy
from composer.models import HuggingFaceModel
# Note: this does not load the weights, just the right model/tokenizer class and config.
# The weights will be loaded by the Composer trainer
model, tokenizer = HuggingFaceModel.hf_from_composer_checkpoint(
    'checkpoints/pretraining/latest-rank0.pt',
    model_instantiation_class='transformers.AutoModelForSequenceClassification',
    model_config_kwargs={'num_labels': 4})
metrics = [CrossEntropy(), MulticlassAccuracy(num_classes=4, average='micro')]
composer_model = HuggingFaceModel(model, tokenizer=tokenizer, metrics=metrics, use_logits=True)
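Conceptually, swapping the head means the pretrained backbone weights are reused while the new classification head starts from a fresh initialization. The toy sketch below illustrates that split by comparing parameter names. Everything in it is hypothetical: the helper plan_weight_loading and the parameter names are made up for illustration and are not the real Electra or Composer names, and it says nothing about how Composer performs the loading internally.

```python
def plan_weight_loading(checkpoint_keys, new_model_keys):
    """Group parameter names by what happens to them when the head is swapped."""
    checkpoint_keys = set(checkpoint_keys)
    new_model_keys = set(new_model_keys)
    return {
        # Present in both: pretrained values can be reused
        'reused': sorted(new_model_keys & checkpoint_keys),
        # Only in the new model: no pretrained values exist for these
        'freshly_initialized': sorted(new_model_keys - checkpoint_keys),
        # Only in the checkpoint: the old head is not needed anymore
        'dropped': sorted(checkpoint_keys - new_model_keys),
    }


# Hypothetical parameter names for a masked LM and a classifier
pretraining_keys = ['encoder.embeddings.weight', 'encoder.layer0.weight', 'mlm_head.weight']
finetuning_keys = ['encoder.embeddings.weight', 'encoder.layer0.weight', 'classifier.weight']

plan = plan_weight_loading(pretraining_keys, finetuning_keys)
for group, names in plan.items():
    print(group, names)
```

This is also why the random seed still matters during finetuning: the new head has no pretrained values to start from.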
The next part should look very familiar if you have already gone through our tutorial on finetuning a pretrained Hugging Face model, as it is the same except for the dataset and the starting model!
We will now finetune on the AG News dataset. We have already downloaded and split the dataset, so now we just need to prepare the dataset for finetuning.
import datasets
text_column_name = 'text'
def tokenize_function(sample):
    return tokenizer(
        text=sample[text_column_name],
        padding="max_length",
        max_length=256,
        truncation=True
    )
tokenized_train = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=['text'],
    load_from_cache_file=False,
)
tokenized_eval = eval_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=['text'],
    load_from_cache_file=False,
)
from torch.utils.data import DataLoader
data_collator = transformers.data.data_collator.default_data_collator
train_dataloader = DataLoader(tokenized_train, batch_size=32, shuffle=False, drop_last=False, collate_fn=data_collator)
eval_dataloader = DataLoader(tokenized_eval, batch_size=32, shuffle=False, drop_last=False, collate_fn=data_collator)
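Note that only the text column is removed this time: the label column stays in the dataset because the classifier needs it. The default collator turns a list of per-example dictionaries into one batch and exposes a per-example label field under the key labels. Here is a simplified pure-Python sketch of that behavior. The helper simple_collate is hypothetical and returns plain lists, whereas the real collator returns tensors.

```python
def simple_collate(features):
    """Combine per-example dicts into one batch dict (illustrative only)."""
    batch = {}
    for key in features[0]:
        # A per-example 'label' becomes the batch-level 'labels' entry
        out_key = 'labels' if key == 'label' else key
        batch[out_key] = [example[key] for example in features]
    return batch


# Two tiny, made-up tokenized examples
examples = [
    {'label': 2, 'input_ids': [101, 7, 102], 'attention_mask': [1, 1, 1]},
    {'label': 0, 'input_ids': [101, 9, 102], 'attention_mask': [1, 1, 1]},
]
batch = simple_collate(examples)
print(sorted(batch))
print(batch['labels'])
```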
Next we will create our optimizer and learning rate scheduler for the finetuning task.
from composer.optim import DecoupledAdamW, LinearWithWarmupScheduler
optimizer = DecoupledAdamW(composer_model.parameters(), lr=1.0e-4, betas=[0.9, 0.98], eps=1.0e-06, weight_decay=3.0e-4)
lr_scheduler = LinearWithWarmupScheduler(t_warmup='0.06dur', alpha_f=0.02)
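The warmup is now given as '0.06dur', a fraction of the total training duration, instead of a batch count like '250ba'. The sketch below converts the time strings used in this tutorial into batch counts. It is a hypothetical helper, not Composer's parser: it handles only the 'ba', 'ep', and 'dur' units, rounds to the nearest batch, and assumes a one-epoch finetuning run of 3,000 batches (96,000 examples at batch size 32).

```python
import re


def duration_to_batches(duration, batches_per_epoch, max_duration_batches=None):
    """Convert a time string such as '250ba', '1ep', or '0.06dur' into batches."""
    match = re.fullmatch(r'(\d+(?:\.\d+)?)(ba|ep|dur)', duration)
    if match is None:
        raise ValueError(f'Unsupported time string: {duration!r}')
    value, unit = float(match.group(1)), match.group(2)
    if unit == 'ba':
        return round(value)
    if unit == 'ep':
        return round(value * batches_per_epoch)
    # 'dur' is a fraction of the whole run, so the run length must be known
    if max_duration_batches is None:
        raise ValueError("'dur' requires max_duration_batches")
    return round(value * max_duration_batches)


finetune_batches_per_epoch = 96000 // 32  # assumption: 3,000 batches per epoch
total = duration_to_batches('1ep', finetune_batches_per_epoch)
warmup = duration_to_batches('0.06dur', finetune_batches_per_epoch, max_duration_batches=total)
print(total, warmup)  # 3000 180
```

Expressing warmup as a fraction of the run keeps the schedule proportionally the same if you later train for more epochs.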
Lastly, we can make our finetuning trainer and train! The only new arguments to the trainer here are load_path, which tells Composer where to load the already trained weights from, and load_weights_only, which tells Composer that we only want to load the weights from the checkpoint, not any other state from the previous training run.
import torch
from composer import Trainer
# Create Trainer Object
trainer = Trainer(
    model=composer_model, # This is the model from the HuggingFaceModel wrapper class.
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='1ep', # Again, training for more epochs is likely to lead to higher performance
    save_folder='checkpoints/finetuning/',
    load_path='checkpoints/pretraining/latest-rank0.pt',
    load_weights_only=True, # We're starting a new training run, so we just want the model weights
    optimizers=optimizer,
    schedulers=[lr_scheduler],
    device='gpu' if torch.cuda.is_available() else 'cpu',
    precision='fp32',
    seed=17,
)
# Start training
trainer.fit()
trainer.close()
Not bad, we got up to 91.5% accuracy on our eval split! Note that this is considerably lower than the state of the art on this task, but we started from a randomly initialized model and did not train for very long, either in pretraining or finetuning!
There are many ways to improve performance. Using a larger model and training for longer are often the first things to try (given a fixed dataset). You can also tweak the hyperparameters, try a different model class, start from pretrained weights instead of a random initialization, or try adding some of Composer's algorithms. We encourage you to play around with these and other things to get familiar with training in Composer.
What next?
You've now seen how to use the Composer Trainer to pretrain and finetune a Hugging Face model on the AG News dataset.
If you want to keep learning more, try looking through some of the documents linked throughout this tutorial to see if you can form a deeper intuition for what's going on in these examples.
In addition, please continue to explore our tutorials and examples! Here are a couple of suggestions:
- Explore more advanced applications of Composer like applying image segmentation to medical images. 
- Learn about callbacks and how to apply early stopping. 
- Check out the benchmarks repo for full examples of training large language models like GPT and BERT, image segmentation models like DeepLab, and more! 
Come get involved with MosaicML!
We'd love for you to get involved with the MosaicML community in any of these ways:
Star Composer on GitHub
Help make others aware of our work by starring Composer on GitHub.
Join the MosaicML Slack
Head on over to the MosaicML Slack to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!
Contribute to Composer
Is there a bug you noticed or a feature you'd like? File an issue or make a pull request!