✅ Checkpointing#

Composer can be configured to automatically save training checkpoints by passing the argument save_folder when creating the Trainer.

To customize the filenames of checkpoints inside save_folder, you can set the save_filename argument. By default, checkpoints will be named like 'ep{epoch}-ba{batch}-rank{rank}' within the save_folder.

In addition, the trainer creates a symlink called 'latest-rank{rank}', which points to the latest saved checkpoint file. You can customize this symlink name by setting the save_latest_filename argument.

The save_folder, save_filename, and save_latest_filename arguments are Python format strings, so you can customize the folder structure to include information such as the rank of the Python process or the current training progress. Please see the CheckpointSaver for the full list of available format variables.

For example:

from composer import Trainer

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="2ep",
    save_folder="./path/to/checkpoints",
    save_filename="ep{epoch}",
    save_latest_filename="latest",
    save_overwrite=True,
)

trainer.fit()
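To make the format-string behavior concrete, here is a minimal sketch that expands the default filename patterns quoted above with plain str.format. It only illustrates the placeholder mechanism; it is not Composer's internal formatting code, and the progress values are made up for the illustration.

```python
# Illustration only: plain str.format, not Composer's internal code.
save_filename = "ep{epoch}-ba{batch}-rank{rank}"  # default pattern quoted above
save_latest_filename = "latest-rank{rank}"        # default symlink pattern

# Hypothetical training progress, chosen for the illustration.
progress = {"epoch": 3, "batch": 1500, "rank": 0}

# Placeholders that a pattern does not use are simply ignored by str.format.
checkpoint_name = save_filename.format(**progress)
latest_name = save_latest_filename.format(**progress)

print(checkpoint_name)  # ep3-ba1500-rank0
print(latest_name)      # latest-rank0
```

Including {rank} in the pattern keeps the files written by different processes from colliding in a shared folder.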

Save Interval#

By default, checkpoints are saved every epoch, but this interval can be configured using the save_interval argument. The save_interval can be an integer (interpreted as a number of epochs), a time string (see the Time Guide for more information), or a function that takes (State, Event) and returns whether a checkpoint should be saved.

For example:

  • save_interval=1 to save every epoch (the default).

  • save_interval="10ep" to save every 10 epochs.

  • save_interval="500ba" to save every 500 batches/steps.

  • save_interval=lambda state, event: state.timestamp.epoch > 50 and event == Event.EPOCH_CHECKPOINT to save every epoch, starting after the 50th epoch.
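The callable form in the last bullet can be tried out without a training run. The sketch below uses small stand-in classes named Event, Timestamp, and State; these are hypothetical simplifications written for this example (the real Composer objects carry much more), and only the predicate itself mirrors the lambda above.

```python
from dataclasses import dataclass
from enum import Enum


class Event(Enum):
    """Stand-in for Composer's Event enum (assumption: only two members shown)."""
    BATCH_CHECKPOINT = "batch_checkpoint"
    EPOCH_CHECKPOINT = "epoch_checkpoint"


@dataclass
class Timestamp:
    """Stand-in timestamp that tracks only the epoch counter."""
    epoch: int


@dataclass
class State:
    """Stand-in training state that holds only the timestamp."""
    timestamp: Timestamp


def save_after_epoch_50(state, event):
    """Same predicate as the lambda above, written as a named function."""
    return state.timestamp.epoch > 50 and event == Event.EPOCH_CHECKPOINT


# Evaluate the predicate at the end of epochs 50, 51, and 52.
decisions = [
    save_after_epoch_50(State(Timestamp(epoch)), Event.EPOCH_CHECKPOINT)
    for epoch in (50, 51, 52)
]
# A batch-level event never triggers a save with this predicate.
batch_decision = save_after_epoch_50(State(Timestamp(60)), Event.BATCH_CHECKPOINT)

print(decisions)       # [False, True, True]
print(batch_decision)  # False
```

Writing the predicate as a named function rather than a lambda makes it easier to unit test before a long training run.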

Putting this together, here's how to save checkpoints:

from composer import Trainer

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="1ep",
    save_filename="ep{epoch}.pt",
    save_folder="./path/to/checkpoints",
    save_overwrite=True,
    save_interval="1ep",  # Save checkpoints every epoch
)
trainer.fit()

The above code will train a model for 1 epoch, and then save the checkpoint.

Anatomy of a Checkpoint#

The above code, when run, will produce the checkpoints below:

>>> trainer.saved_checkpoints
['./path/to/checkpoints/ep1.pt']
>>> latest_checkpoint = trainer.saved_checkpoints[-1]
>>> state_dict = torch.load(latest_checkpoint)
>>> list(state_dict)
['state', 'rng']
>>> list(state_dict['state'].keys())
['model', 'optimizers', 'schedulers', 'algorithms', 'callbacks', 'scaler', 'timestamp', 'rank_zero_seed', 'run_name', 'dataset_state', 'integrations', 'metadata']

Resume training#

To resume training from a previous checkpoint, set the load_path argument of the Trainer to the checkpoint filepath. When the Trainer is initialized, the checkpoint state will be restored, and Trainer.fit() will continue training from where the checkpoint left off.

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="90ep",
    save_overwrite=True,
    load_path="./path/to/checkpoints/ep25.pt",
)
trainer.fit()

The above code will load the checkpoint from epoch 25 and continue training for another 65 epochs (to reach 90 epochs total).

Different model or optimizer objects that are passed into the trainer when it is resumed will be respected. However, an error will be raised if the weights or state from the checkpoint are not compatible with these new objects.

Note

Only the attributes in State.serialized_attributes are serialized and loaded. By default, they are:

Attribute        Description
model            The model under training.
optimizers       The optimizers being used to train the model.
schedulers       The learning rate schedulers.
algorithms       The algorithms used for training.
callbacks        The callbacks used for training.
scaler           The gradient scaler in use for mixed precision training.
timestamp        The timestamp that tracks training loop progress.
rank_zero_seed   The seed of the rank zero process.
train_metrics    The current training metrics.
eval_metrics     The current validation metrics.
run_name         The run name for training.
dataset_state    The dataset iteration state.

All other trainer arguments (e.g. max_duration or precision) will use either the defaults or what is passed in when reconstructing the trainer.

Saving for Inference#

By default, the Trainer stores the entire training state in each checkpoint. If you would like to store only the model weights in a checkpoint (plus metadata and integrations), set save_weights_only=True. The metadata includes information about the environment used to train the model, as well as any integrations, such as HuggingFace, and those integrations' metadata.

from composer.trainer import Trainer

trainer = Trainer(
    ...,
    save_folder="checkpoints",
    save_weights_only=True,
    save_overwrite=True,
)

trainer.fit()

Saving Multiple Checkpoint Types#

To save multiple checkpoint types, such as full checkpoints and weights-only checkpoints, the CheckpointSaver can be passed directly into the callbacks argument of the trainer. Each CheckpointSaver can have its own save folder, interval, and other parameters.

When configuring checkpoints via the callbacks, it is not necessary to specify the save_folder or other checkpoint saving parameters directly on the trainer.

from composer.trainer import Trainer
from composer.callbacks import CheckpointSaver

trainer = Trainer(
    ...,
    callbacks=[
        CheckpointSaver(
            folder='full_checkpoints',
            save_interval='5ep',
            overwrite=True,
            num_checkpoints_to_keep=1,  # only keep the latest, full checkpoint
        ),
        CheckpointSaver(
            folder='weights_only_checkpoints',
            weights_only=True,
            overwrite=True,
        ),
    ],
)

trainer.fit()
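As a rough illustration of what the two savers above would leave on disk, the sketch below models each saver as "save every N epochs, keep the last K". The helper epochs_on_disk is hypothetical, the 10-epoch run length is made up, and the sketch assumes that the second saver uses the per-epoch default interval and that a negative keep value means "keep everything".

```python
def epochs_on_disk(total_epochs, interval_epochs, keep=-1):
    """Return the epochs whose checkpoints remain after training.

    keep < 0 is assumed to mean "never delete"; otherwise only the
    `keep` most recent checkpoints survive.
    """
    saved = [e for e in range(1, total_epochs + 1) if e % interval_epochs == 0]
    if keep < 0:
        return saved
    return saved[len(saved) - min(keep, len(saved)):]


# Hypothetical 10-epoch run with the two savers configured above.
full_checkpoints = epochs_on_disk(10, interval_epochs=5, keep=1)
weights_only_checkpoints = epochs_on_disk(10, interval_epochs=1)

print(full_checkpoints)          # [10]
print(weights_only_checkpoints)  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
```

Under these assumptions, the large full checkpoint exists only once, while the smaller weights-only checkpoints accumulate for every epoch.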

Fine-tuning#

The Trainer will only load the model weights from the checkpoint if load_weights_only=True or if the checkpoint was saved with save_weights_only=True. This is especially useful for model fine-tuning, since the rest of the trainer's state no longer applies.

If the fine-tuned model contains different parameter names than the model in the checkpoint, set load_strict_model_weights=False to ignore mismatches in model parameter names between the serialized model state and the new model object. Parameters with the same name are expected to have the same shape and will have their state restored. Parameters with different names will be ignored.

ft_trainer = Trainer(
    model=finetune_model,
    train_dataloader=finetune_dataloader,
    max_duration="10ep",
    load_path="./path/to/checkpoints/ep50.pt",
    load_weights_only=True,
    load_strict_model_weights=False,
)

ft_trainer.fit()

This example will load only the model weights from the epoch 50 checkpoint and then continue training on the fine-tuning dataloader for 10 epochs.
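The name-matching rule for non-strict loading can be sketched with plain dictionaries. The function load_matching_weights below is a hypothetical, simplified stand-in written for this example; it is not the code Composer or PyTorch runs, and it uses lists in place of tensors.

```python
def load_matching_weights(model_state, checkpoint_state, strict=True):
    """Copy checkpoint values into model_state for parameters with equal names.

    Returns (missing, unexpected): names only in the model, and names only
    in the checkpoint. With strict=True, any mismatch raises instead.
    """
    missing = sorted(set(model_state) - set(checkpoint_state))
    unexpected = sorted(set(checkpoint_state) - set(model_state))
    if strict and (missing or unexpected):
        raise KeyError(f"missing={missing}, unexpected={unexpected}")
    for name in set(model_state) & set(checkpoint_state):
        model_state[name] = checkpoint_state[name]
    return missing, unexpected


# Hypothetical parameter names: the backbone is shared, the head is new.
checkpoint = {"backbone.weight": [1.0, 2.0], "old_head.weight": [9.0]}
model = {"backbone.weight": [0.0, 0.0], "new_head.weight": [0.0]}

missing, unexpected = load_matching_weights(model, checkpoint, strict=False)

print(model["backbone.weight"])  # [1.0, 2.0]  (restored from the checkpoint)
print(model["new_head.weight"])  # [0.0]       (left at its initial value)
print(missing, unexpected)       # ['new_head.weight'] ['old_head.weight']
```

With strict=True the same call would raise, which is the safer default when the two models are expected to match exactly.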

Loading Weights Externally#

The model weights are located at state_dict["state"]["model"] within the stored checkpoint. To load them into a model outside of a Trainer, use torch.load():

model = Model(num_channels, num_classes)
state_dict = torch.load("./path/to/checkpoints/ep1.pt")
model.load_state_dict(state_dict["state"]["model"])

Uploading Checkpoints to Object Store#

Checkpoints can also be saved to and loaded from your object store of choice (e.g. AWS S3 or Google Cloud Storage). Writing checkpoints to an object store is a two-step process. The checkpoints are first written to the local filesystem, and then the RemoteUploaderDownloader logger will upload checkpoints to the specified object store.

Behind the scenes, the RemoteUploaderDownloader uses Apache Libcloud.

The easiest way to upload checkpoints to S3, OCI, or GCS is to prefix your save_folder with s3://, oci://, or gs://. All other checkpoint arguments remain the same. For example, save_filename will be the name of the checkpoint file that gets uploaded to the S3, OCI, or GCS URI that you specified.

For example, for S3:

from composer.trainer import Trainer

# Save checkpoints every epoch to s3://my_bucket/checkpoints
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='10ep',
    save_folder='s3://my_bucket/checkpoints',
    save_interval='1ep',
    save_overwrite=True,
    save_filename='ep{epoch}.pt',
    save_num_checkpoints_to_keep=0,  # delete all checkpoints locally
)

trainer.fit()

This will train your model, save the checkpoints locally, upload them to the S3 bucket my_bucket, and delete the checkpoints from the local disk. The checkpoints will be located on S3 inside your bucket as checkpoints/ep3.pt for the third epoch's checkpoint, for example. The full URI in this case would be: s3://my_bucket/checkpoints/ep3.pt.
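The relationship between save_folder, save_filename, and the resulting object URI can be checked with the standard library. The helper remote_checkpoint_uri below is hypothetical and only joins the two strings the way the paragraph above describes; it does not call Composer or any cloud API.

```python
from urllib.parse import urlparse


def remote_checkpoint_uri(save_folder, save_filename, **format_vars):
    """Join a URI-style save_folder with a formatted save_filename."""
    return f"{save_folder.rstrip('/')}/{save_filename.format(**format_vars)}"


uri = remote_checkpoint_uri("s3://my_bucket/checkpoints", "ep{epoch}.pt", epoch=3)

# Split the URI into the backend scheme, the bucket, and the object key.
parsed = urlparse(uri)
scheme, bucket, key = parsed.scheme, parsed.netloc, parsed.path.lstrip("/")

print(uri)     # s3://my_bucket/checkpoints/ep3.pt
print(scheme)  # s3
print(bucket)  # my_bucket
print(key)     # checkpoints/ep3.pt
```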

Uploading checkpoints to Coreweave's object store is very similar to the S3 example above. The only difference is that you must set your Coreweave endpoint URL. To do this, set the S3_ENDPOINT_URL environment variable before creating the Trainer, like so:

import os

os.environ['S3_ENDPOINT_URL'] = 'https://object.las1.coreweave.com'
from composer.trainer import Trainer

# Save checkpoints every epoch to Coreweave object store: s3://my_bucket/checkpoints
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='10ep',
    save_folder='s3://my_bucket/checkpoints',
    save_interval='1ep',
    save_overwrite=True,
    save_filename='ep{epoch}.pt',
    save_num_checkpoints_to_keep=0,  # delete all checkpoints locally
)

trainer.fit()

Uploading checkpoints to Cloudflare's R2 object store works the same way as for Coreweave: use the S3 uploading code above and set the S3_ENDPOINT_URL environment variable before creating the Trainer, like so:

import os

os.environ['S3_ENDPOINT_URL'] = 'https://MY_ACCOUNT_ID.r2.cloudflarestorage.com/'
from composer.trainer import Trainer

# Save checkpoints every epoch to Cloudflare R2: s3://my_bucket/checkpoints
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='10ep',
    save_folder='s3://my_bucket/checkpoints',
    save_interval='1ep',
    save_overwrite=True,
    save_filename='ep{epoch}.pt',
    save_num_checkpoints_to_keep=0,  # delete all checkpoints locally
)

trainer.fit()

Similarly for OCI:

from composer.trainer import Trainer

# Save checkpoints every epoch to oci://my_bucket/checkpoints
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='10ep',
    save_folder='oci://my_bucket/checkpoints',
    save_interval='1ep',
    save_overwrite=True,
    save_filename='ep{epoch}.pt',
    save_num_checkpoints_to_keep=0,  # delete all checkpoints locally
)

trainer.fit()

This will train your model, save the checkpoints locally, upload them to the OCI bucket my_bucket, and delete the checkpoints from the local disk. The checkpoints will be located on OCI inside your bucket as checkpoints/ep3.pt for the third epoch's checkpoint, for example. The full URI in this case would be: oci://my_bucket/checkpoints/ep3.pt.

Similarly for GCS:

from composer.trainer import Trainer

# Save checkpoints every epoch to gs://my_bucket/checkpoints
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='10ep',
    save_folder='gs://my_bucket/checkpoints',
    save_interval='1ep',
    save_overwrite=True,
    save_filename='ep{epoch}.pt',
    save_num_checkpoints_to_keep=0,  # delete all checkpoints locally
)

trainer.fit()

This will train your model, save the checkpoints locally, upload them to the GCS bucket my_bucket, and delete the checkpoints from the local disk. The checkpoints will be located on GCS inside your bucket as checkpoints/ep3.pt for the third epoch's checkpoint, for example. The full URI in this case would be: gs://my_bucket/checkpoints/ep3.pt. Note: For GCS, remember to set the environment variables GCS_KEY and GCS_SECRET to your HMAC access id and secret, respectively, or the save operation will fail.
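Because a missing credential only surfaces when the upload is attempted, it can help to check the two GCS variables before constructing the Trainer. The helper missing_gcs_credentials below is a hypothetical preflight check written for this example; it is not part of Composer.

```python
import os

# Variable names taken from the GCS note above.
REQUIRED_GCS_VARS = ("GCS_KEY", "GCS_SECRET")


def missing_gcs_credentials(env=None):
    """Return the required GCS variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_GCS_VARS if not env.get(name)]


# Demonstration with explicit dictionaries instead of the real environment.
complete = missing_gcs_credentials({"GCS_KEY": "my-id", "GCS_SECRET": "my-secret"})
partial = missing_gcs_credentials({"GCS_KEY": "my-id"})

print(complete)  # []
print(partial)   # ['GCS_SECRET']
```

In a training script, calling missing_gcs_credentials() with no argument checks the real environment, and a non-empty result can be turned into an early error.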

There are a few additional trainer arguments which can be helpful to configure:

  • save_num_checkpoints_to_keep: Set this parameter to remove checkpoints from the local disk after they have been uploaded. For example, setting this parameter to 1 will only keep the latest checkpoint locally; setting it to 0 will remove each checkpoint after it has been uploaded. Checkpoints are never deleted from object stores.

  • save_remote_file_name: To customize how checkpoints are named in the cloud bucket, modify this parameter. By default, they will be named as '{run_name}/checkpoints/ep{epoch}-ba{batch}-rank{rank}'. See the CheckpointSaver documentation for the available format variables.
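The local pruning behavior described for save_num_checkpoints_to_keep can be simulated with two temporary directories, one standing in for the local disk and one for the bucket. Everything in this sketch is hypothetical (the function name, the fake file contents, and the copy that stands in for the upload); it only illustrates that local files are pruned after upload while remote copies are left alone.

```python
import shutil
import tempfile
from pathlib import Path


def save_upload_prune(local_dir, remote_dir, epoch, keep):
    """Write a fake checkpoint, 'upload' it by copying, then prune local files."""
    ckpt = local_dir / f"ep{epoch}.pt"
    ckpt.write_bytes(b"fake checkpoint")       # stands in for the real save
    shutil.copy(ckpt, remote_dir / ckpt.name)  # stands in for the upload
    if keep >= 0:
        # Oldest first, ordered by the epoch number embedded in the filename.
        local = sorted(local_dir.glob("ep*.pt"), key=lambda p: int(p.stem[2:]))
        for stale in local[: max(0, len(local) - keep)]:
            stale.unlink()


with tempfile.TemporaryDirectory() as tmp:
    local_dir, remote_dir = Path(tmp, "local"), Path(tmp, "remote")
    local_dir.mkdir()
    remote_dir.mkdir()
    for epoch in range(1, 4):
        save_upload_prune(local_dir, remote_dir, epoch, keep=1)
    local_files = sorted(p.name for p in local_dir.iterdir())
    remote_files = sorted(p.name for p in remote_dir.iterdir())

print(local_files)   # ['ep3.pt']
print(remote_files)  # ['ep1.pt', 'ep2.pt', 'ep3.pt']
```

With keep=0 the local directory would be empty after every step, which matches the "delete all checkpoints locally" comment in the examples above.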

Using a URI prefix in save_folder is equivalent to creating a RemoteUploaderDownloader object and adding it to loggers. Creating the object yourself is a more involved operation, but it is necessary for uploading checkpoints to other cloud object stores that are not supported by URI:

from composer.loggers import RemoteUploaderDownloader
from composer.trainer import Trainer

remote_uploader_downloader = RemoteUploaderDownloader(
    bucket_uri="libcloud://checkpoint-debugging",
    backend_kwargs={
        "provider": "s3",  # The Apache Libcloud provider name
        "container": "checkpoint-debugging",  # The name of the cloud container (i.e. bucket) to use.
        "provider_kwargs": {  # The Apache Libcloud provider driver initialization arguments
            'key': 'provider_key',  # The cloud provider key.
            'secret': '*******',  # The cloud provider secret.
            # Any additional arguments required for the cloud provider.
        },
    },
)

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='10ep',
    save_folder='checkpoints',
    save_interval='1ep',
    save_overwrite=True,
    save_filename='ep{epoch}.pt',
    save_num_checkpoints_to_keep=0,  # delete all checkpoints locally
    loggers=[remote_uploader_downloader],
)

trainer.fit()

Loading Checkpoints from Object Store#

Checkpoints saved to an object store can be loaded in the same way as files saved on disk. Provide the LibcloudObjectStore to the trainer's load_object_store argument (you can also provide the full RemoteUploaderDownloader object). The load_path argument should be the path to the checkpoint file within the container/bucket.

from composer.utils import LibcloudObjectStore
from composer.trainer import Trainer

object_store = LibcloudObjectStore(
    provider="s3",  # The Apache Libcloud provider name
    container="checkpoint-debugging",  # The name of the cloud container (i.e. bucket) to use.
    provider_kwargs={  # The Apache Libcloud provider driver initialization arguments
        'key': 'provider_key',  # The cloud provider key.
        'secret': '*******',  # The cloud provider secret.
        # Any additional arguments required for the cloud provider.
    },
)

new_trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="10ep",
    load_path="checkpoints/ep1.pt",
    load_object_store=object_store,
)

new_trainer.fit()

An easier way to load checkpoints from S3, OCI, or GCS specifically is to use a URI starting with s3://, oci://, or gs://. If you use an S3, OCI, or GCS URI, it is not necessary to specify a load_object_store. Note that for other object stores, like WandB or LibCloud, you must still specify a load_object_store.

new_trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="10ep",
    load_path="s3://checkpoint-debugging/checkpoints/ep1.pt",
)

new_trainer.fit()

This will load the first epoch's checkpoint from S3 and resume training in the second epoch.

Similarly for OCI:

new_trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="10ep",
    load_path="oci://checkpoint-debugging/checkpoints/ep1.pt",
)

new_trainer.fit()

This will load the first epoch's checkpoint from OCI and resume training in the second epoch.

Similarly for GCS:

new_trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="10ep",
    load_path="gs://checkpoint-debugging/checkpoints/ep1.pt",
)

new_trainer.fit()

This will load the first epoch's checkpoint from GCS and resume training in the second epoch. Note: For GCS, remember to set the environment variables GCS_KEY and GCS_SECRET to your HMAC access id and secret, respectively, or the load operation will fail.

Warning

Do not load checkpoints from untrusted sources as they may contain malicious code. Users should ensure the proper sanity checks are in place before loading checkpoints.
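One possible sanity check, sketched here under the assumption that the checkpoint's publisher provides a SHA-256 digest through a trusted channel, is to compare the file's digest before loading it. The helper names are hypothetical. A matching digest only shows that the file is the one the publisher intended to distribute; it does not make an untrusted publisher's checkpoint safe.

```python
import hashlib
import tempfile
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Compute the SHA-256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_digest(path, expected_sha256):
    """Return True only if the file's digest equals the published one."""
    return sha256_of(path) == expected_sha256.lower()


# Demonstration with a fake checkpoint file and a made-up published digest.
with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp, "ep1.pt")
    ckpt.write_bytes(b"fake checkpoint bytes")
    published = hashlib.sha256(b"fake checkpoint bytes").hexdigest()

    intact_ok = matches_published_digest(ckpt, published)

    ckpt.write_bytes(b"tampered bytes")  # simulate a modified download
    tampered_ok = matches_published_digest(ckpt, published)

print(intact_ok)    # True
print(tampered_ok)  # False
```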

Saving and Loading Sharded Checkpoints#

Using FSDP, you can save and load sharded checkpoints with Composer. See these docs for more info and some examples.

API Reference#

  • RemoteUploaderDownloader for saving checkpoints to cloud storage.

  • Trainer for the trainer checkpoint arguments.

  • CheckpointSaver for the CheckpointSaver arguments.

  • LibcloudObjectStore for setting up libcloud-supported object stores.

  • composer.utils.checkpoint for the underlying utilities to manually save and load checkpoints.