Source code for composer.utils.checkpoint

# Copyright 2021 MosaicML. All Rights Reserved.

"""Utilities for working with training checkpoints."""

from __future__ import annotations

import contextlib
import logging
import os
import pathlib
import shutil
import tarfile
import tempfile
import textwrap
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import torch

from composer.utils import dist, reproducibility
from composer.utils.file_helpers import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, GetFileNotFoundException,
                                         format_name_with_dist_and_time, get_file, is_tar)
from composer.utils.object_store import ObjectStore

if TYPE_CHECKING:
    from composer.core.state import State
    from composer.loggers.logger import Logger

log = logging.getLogger(__name__)

__all__ = ["load_checkpoint", "save_checkpoint"]

_COMPOSER_STATES_FILENAME = "composer_states.pt"
_DEEPSPEED_TAG = "deepspeed"  # always tag with the same, deterministic name. We'll rename the tarball to the appropriate name.


def _format_path_with_rank_zero(path: str) -> str:
    """Formats ``path`` with the rank zero values."""
    return path.format(
        rank=0,
        local_rank=0,
        node_rank=0,
    )


def _format_path_with_current_rank(path: str) -> str:
    """Formats ``path`` formatted with the current rank values."""
    return path.format(
        rank=dist.get_global_rank(),
        local_rank=dist.get_local_rank(),
        node_rank=dist.get_node_rank(),
    )
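
# Illustrative sketch (not part of the library): given a path format string such as
# ``my_model/ep1-rank{rank}.tar``, the two helpers above resolve it as follows
# (values shown assume a process with global rank 2):
#
#   >>> _format_path_with_rank_zero("my_model/ep1-rank{rank}.tar")
#   'my_model/ep1-rank0.tar'
#   >>> _format_path_with_current_rank("my_model/ep1-rank{rank}.tar")
#   'my_model/ep1-rank2.tar'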


def _get_write_mode(name: str) -> str:
    """Get the write mode to use with :func:`tarfile.open`."""
    if name.endswith('.tar'):
        return 'w'
    if name.endswith(".tar.gz") or name.endswith(".tgz"):
        return "w:gz"
    if name.endswith(".tar.bz2"):
        return "w:bz2"
    if name.endswith(".tar.lzma"):
        return "w:xz"
    raise ValueError(f"{name} does not end with a valid tarfile extension.")
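
# Illustrative sketch (not part of the library): the extension-to-mode mapping above, e.g.
#
#   >>> _get_write_mode("ep1-ba42-rank0.tar")
#   'w'
#   >>> _get_write_mode("ep1-ba42-rank0.tar.gz")
#   'w:gz'
#   >>> _get_write_mode("ep1-ba42-rank0.pt")  # not a tarball -> ValueError
#   Traceback (most recent call last):
#       ...
#   ValueError: ep1-ba42-rank0.pt does not end with a valid tarfile extension.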


def load_checkpoint(
    path: str,
    state: State,
    object_store: Optional[ObjectStore] = None,
    load_weights_only: bool = False,
    strict_model_weights: bool = False,
    chunk_size: int = 1_048_576,
    progress_bar: bool = True,
):
    """Load a checkpoint from a local file, URI, or cloud object store into ``state``.

    Args:
        path (str): The path format string to an existing checkpoint file.

            It can be a path to a file on the local disk, a URL, or, if ``object_store`` is set, the object name
            for a checkpoint in a cloud bucket.

            When using `DeepSpeed ZeRO <https://www.deepspeed.ai/tutorials/zero/>`_, checkpoints are sharded by
            rank. Instead of hard-coding the rank in the ``path``, use the following format variables:

            +------------------------+-------------------------------------------------------+
            | Variable               | Description                                           |
            +========================+=======================================================+
            | ``{rank}``             | The global rank, as returned by                       |
            |                        | :func:`~.dist.get_global_rank`.                       |
            +------------------------+-------------------------------------------------------+
            | ``{local_rank}``       | The local rank of the process, as returned by         |
            |                        | :func:`~.dist.get_local_rank`.                        |
            +------------------------+-------------------------------------------------------+
            | ``{node_rank}``        | The node rank, as returned by                         |
            |                        | :func:`~.dist.get_node_rank`.                         |
            +------------------------+-------------------------------------------------------+

            For example, suppose that checkpoints are stored in the following structure:

            .. code-block::

                my_model/ep1-rank0.tar
                my_model/ep1-rank1.tar
                my_model/ep1-rank2.tar
                ...

            Then, ``path`` should be set to ``my_model/ep1-rank{rank}.tar``, and all ranks will load the
            correct state.

        state (State): The :class:`~composer.core.state.State` to load the checkpoint into.
        object_store (ObjectStore, optional): If the ``path`` is in an object store
            (e.g. AWS S3 or Google Cloud Storage), an instance of :class:`~.ObjectStore` which will be used
            to retrieve the checkpoint. Otherwise, if the checkpoint is a local filepath, set to ``None``.
            (default: ``None``)
        load_weights_only (bool, optional): Whether or not to only restore the model weights from the checkpoint
            without restoring the associated state. (default: ``False``)
        strict_model_weights (bool, optional): Whether or not to force that the checkpointed weights must exactly
            match the model weights. (default: ``False``)
        chunk_size (int, optional): Chunk size (in bytes) to use when downloading checkpoints.
            Ignored if the checkpoint is a local file path. (default: ``1_048_576`` bytes (1 MB))
        progress_bar (bool, optional): Whether or not to show a progress bar when downloading checkpoints.
            Ignored if the checkpoint is a local file path. (default: ``True``)

    Returns:
        Optional[List[Dict[str, Any]]]: The RNG state dicts, indexed by global rank, if ``load_weights_only``
            is ``False``. Otherwise, ``None``.
    """
    # download the checkpoint to the node-local folder
    tempdir_ctx = tempfile.TemporaryDirectory() if dist.get_local_rank() == 0 else contextlib.nullcontext(None)
    with tempdir_ctx as tempdir:
        try:
            node_checkpoint_folder = _get_node_checkpoint_download_folder(tempdir)
            composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n = _download_checkpoint(
                path=path,
                node_checkpoint_folder=node_checkpoint_folder,
                object_store=object_store,
                chunk_size=chunk_size,
                progress_bar=progress_bar,
            )
            rng_state_dicts = _restore_checkpoint(
                state,
                composer_states_filepath,
                extracted_rank_n,
                extracted_checkpoint_folder,
                load_weights_only=load_weights_only,
                strict_model_weights=strict_model_weights,
            )
        finally:
            # Wait for all ranks to finish restoring the checkpoint before releasing the tempdir, since tempdir can
            # be a shared resource between nodes.
            dist.barrier()

    log.info("%s loaded from %s", "Model weights" if load_weights_only else "Trainer checkpoint", path)
    return rng_state_dicts


def _get_node_checkpoint_download_folder(path: Optional[str]) -> str:
    """Broadcasts the ``path`` from the LOCAL rank zero to all LOCAL ranks."""
    local_rank_zero = dist.get_local_world_size() * dist.get_node_rank()
    paths = dist.all_gather_object(path)
    local_rank_zero_path = paths[local_rank_zero]
    assert local_rank_zero_path is not None, "local rank zero provides the path"
    return local_rank_zero_path


def _download_checkpoint(
    path: str,
    node_checkpoint_folder: str,
    object_store: Optional[ObjectStore],
    chunk_size: int,
    progress_bar: bool,
) -> Tuple[str, Optional[str], bool]:
    """Download the checkpoint stored at ``path``, potentially in ``object_store``, to ``node_checkpoint_folder``.

    Returns a tuple of (``composer_states_filepath``, ``extracted_checkpoint_folder``, ``extracted_rank_n``).

    *   The ``composer_states_filepath`` is the path to the composer states, which can be passed into
        :meth:`torch.load`.
    *   The ``extracted_checkpoint_folder`` is the path to the checkpoint folder, which can be passed into
        :meth:`deepspeed.DeepSpeedEngine.load_checkpoint`.
    *   The ``extracted_rank_n`` is a boolean flag indicating whether a tarball was extracted on global
        rank greater than 0.
    """
    rank_zero_checkpoint_filepath = os.path.join(node_checkpoint_folder, "rank0_checkpoint")
    rank_n_checkpoint_filepath = os.path.join(node_checkpoint_folder, f"rank{dist.get_global_rank()}_checkpoint")
    extracted_checkpoint_folder = None
    extracted_rank_n = False
    if is_tar(path):
        extracted_checkpoint_folder = os.path.join(node_checkpoint_folder, "checkpoint")
        composer_states_filepath = os.path.join(extracted_checkpoint_folder, _COMPOSER_STATES_FILENAME)
    else:
        # it's not an archive; it's just the composer state dict
        # and only rank zero has this file
        extracted_checkpoint_folder = None
        composer_states_filepath = rank_zero_checkpoint_filepath

    try:
        if dist.get_local_rank() == 0:
            # every NODE needs the GLOBAL rank zero checkpoint
            path = _format_path_with_rank_zero(path)
            get_file(destination=rank_zero_checkpoint_filepath,
                     path=path,
                     object_store=object_store,
                     chunk_size=chunk_size,
                     progress_bar=progress_bar)
            if extracted_checkpoint_folder is not None:
                try:
                    with tarfile.open(rank_zero_checkpoint_filepath) as tarball:
                        tarball.extractall(extracted_checkpoint_folder)
                except FileNotFoundError:
                    # Not re-raising the file-not-found error as that is irrelevant;
                    # the underlying issue is that the checkpoint file does not exist on the disk
                    # or could not be downloaded
                    raise RuntimeError(f"Checkpoint {path} does not exist")

        if rank_zero_checkpoint_filepath != rank_n_checkpoint_filepath:
            # every RANK needs ITS OWN checkpoint.
            # But, the global rank zero is a special case -- these files are the same!
            assert dist.get_global_rank() != 0, "invariant violation"

            try:
                get_file(destination=rank_n_checkpoint_filepath,
                         path=_format_path_with_current_rank(path),
                         object_store=object_store,
                         chunk_size=chunk_size,
                         progress_bar=progress_bar)
            except GetFileNotFoundException:
                # Allowing not-found errors to be ignored as sometimes there won't be rank-local checkpoints
                # (e.g. when not using deepspeed)
                pass

            if extracted_checkpoint_folder is not None:
                try:
                    # it's an archive and needs to be extracted
                    with tarfile.open(rank_n_checkpoint_filepath) as tarball:
                        tarball.extractall(extracted_checkpoint_folder)
                        extracted_rank_n = True
                except FileNotFoundError:
                    # this will happen most of the time (i.e. whenever deepspeed
                    # is not being used) so not logging anything
                    pass

    finally:
        # Wait for all checkpoints on the node to finish downloading
        # Putting the barrier in a finally so the rank will always block on the barrier,
        # even if it has an exception.
        # Any exception will be re-raised after the barrier passes. The launcher script
        # will detect the process crash and terminate the other ranks
        dist.barrier()

    return composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n


def _restore_checkpoint(
    state: State,
    composer_states_filepath: str,
    extracted_rank_n: bool,
    extracted_checkpoint_folder: Optional[str],
    load_weights_only: bool,
    strict_model_weights: bool,
) -> Optional[List[Dict[str, Any]]]:
    """Restore a checkpoint into ``state`` and return the rng state dicts (if ``load_weights_only`` is False)."""
    # Now, all ranks load the checkpoint that local rank zero downloaded
    state_dict = torch.load(composer_states_filepath, map_location='cpu')
    log.debug(f"Loaded checkpoint with keys {state_dict.keys()} and state keys {state_dict['state'].keys()}")

    if state.is_model_deepspeed:
        if extracted_checkpoint_folder is None:
            raise RuntimeError("Deepspeed checkpoints require a tarball, not a weights file.")

        global_rank = dist.get_global_rank()
        if global_rank > 0 and not extracted_rank_n:
            raise RuntimeError(f"Deepspeed checkpoint missing for rank {global_rank}")

        load_path, _ = state.deepspeed_model.load_checkpoint(
            extracted_checkpoint_folder,
            tag=_DEEPSPEED_TAG,
            load_module_only=load_weights_only,
            load_module_strict=strict_model_weights,
        )
        if load_path is None:
            raise RuntimeError(f"Failed to load DeepSpeed checkpoint")
    elif load_weights_only:
        state.load_model_state(state_dict['state'], strict=strict_model_weights)

    if not load_weights_only:
        state.load_state_dict(state_dict['state'])
        return state_dict['rng']
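
# For reference, a tarball checkpoint consumed by ``_restore_checkpoint`` has the following
# layout once extracted (see ``save_checkpoint`` below; the ``deepspeed/`` directory is only
# present when DeepSpeed is used, and its contents are determined by DeepSpeed itself):
#
#   <extracted_checkpoint_folder>/
#       composer_states.pt   # loaded with torch.load above
#       deepspeed/           # passed to deepspeed_model.load_checkpoint with tag="deepspeed"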


def save_checkpoint(state: State,
                    logger: Logger,
                    filename: str = "ep{epoch}-ba{batch}-rank{rank}",
                    *,
                    weights_only: bool = False) -> List[pathlib.Path]:
    state_dict = {
        'state': state.state_dict(),
        'rng': reproducibility.get_rng_state(),
    }
    if weights_only and not state.is_model_deepspeed:
        state_dict['state'] = {"model": state_dict['state']['model']}

    checkpoint_filepath = format_name_with_dist_and_time(filename, logger.run_name, state.timer.get_timestamp())
    if state.is_model_deepspeed and not is_tar(checkpoint_filepath):
        # Deepspeed requires tarballs; appending `.tar`
        checkpoint_filepath += ".tar"

    with tempfile.TemporaryDirectory() as tmpdir:
        composer_states_filepath = os.path.join(tmpdir, _COMPOSER_STATES_FILENAME)
        if dist.get_global_rank() == 0:
            # Only rank zero saves the composer state dict
            with open(composer_states_filepath, 'xb') as f:
                torch.save(state_dict, f)

        if state.is_model_deepspeed:
            state.deepspeed_model.save_checkpoint(tmpdir, _DEEPSPEED_TAG)

        # Move the checkpoint to the correct location
        checkpoint_dirname = os.path.dirname(checkpoint_filepath)

        if is_tar(checkpoint_filepath) and (state.is_model_deepspeed or dist.get_global_rank() == 0):
            # Either deepspeed (and every rank needs to call this),
            # or not deepspeed (but using an archive), in which case only rank zero should call this.
            if checkpoint_dirname:
                os.makedirs(checkpoint_dirname, exist_ok=True)
            write_mode = _get_write_mode(checkpoint_filepath)
            with tarfile.open(checkpoint_filepath, write_mode) as tarball:
                # add files flat to the tarball with the specified compression
                tarball.add(tmpdir, arcname="")
        elif dist.get_global_rank() == 0:
            # if not an archive, then only saving the states
            # only rank zero saves the state dict
            if checkpoint_dirname:
                os.makedirs(checkpoint_dirname, exist_ok=True)
            shutil.move(composer_states_filepath, checkpoint_filepath)
        else:
            checkpoint_filepath = None

        # Ensure that all processes wait for the checkpoint to be saved.
        dist.barrier()

    if checkpoint_filepath is not None:
        log.info('Saved checkpoint at %s', checkpoint_filepath)

    # Gather the paths across ranks.
    paths = dist.all_gather_object(checkpoint_filepath)
    paths = list(pathlib.Path(path) for path in paths if path is not None)

    return paths


save_checkpoint.__doc__ = f"""Checkpoint the training ``state``.

Args:
    state (State): The training state.
    logger (Logger): The logger.
    filename (str): A format string describing how to name checkpoints.
        (default: ``'ep{{epoch}}-ba{{batch}}-rank{{rank}}'``)

        The following format variables are available:

        {textwrap.indent(FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, prefix='        ')}

        .. note::

            *   By default, only the rank zero process will save a checkpoint file.

            *   When using DeepSpeed, each rank will save a checkpoint file in tarball format. DeepSpeed
                requires tarball format, as it saves model and optimizer states in separate files.
                Ensure that ``'{{rank}}'`` appears within the ``filename``. Otherwise, multiple ranks
                may attempt to write to the same file(s), leading to corrupted checkpoints. If no tarball
                file extension is specified, ``.tar`` will be used.

            *   To use compression (regardless of whether DeepSpeed is enabled), set the file extension
                to ``'.tar.gz'``, ``'.tgz'``, ``'.tar.bz2'``, or ``'.tar.lzma'`` (depending on the desired
                compression algorithm).

        .. warning::

            Using compression will block the training loop while checkpoints are being compressed. As such,
            we recommend saving checkpoints without compression.

        Consider the following scenario, where:

        *   The default ``name='ep{{epoch}}-ba{{batch}}-rank{{rank}}'`` is used.
        *   The current epoch count is ``1``.
        *   The current batch count is ``42``.

        When DeepSpeed is not being used, the rank zero process will save the checkpoint to ``'ep1-ba42-rank0'``.
        When DeepSpeed is being used, each rank (process) will save checkpoints to::

            ep1-ba42-rank0.tar
            ep1-ba42-rank1.tar
            ep1-ba42-rank2.tar
            ...

    weights_only (bool, optional): If ``True``, save only the model weights instead of the entire training
        state. (default: ``False``)

        .. note::

            When using DeepSpeed, this parameter must be ``False``. Weights-only checkpointing is not currently
            compatible with DeepSpeed.

Returns:
    List[pathlib.Path]: The list of checkpoint files saved, indexed by the rank of the process.

    .. note::

        When using DeepSpeed, each process (rank) saves its own checkpoint file.
        When doing multi-node training, the filepaths are valid only on each process's node; Composer does not
        move checkpoint files between nodes.

        Otherwise, when not using DeepSpeed, the list will contain only one filepath,
        since only the rank zero process saves checkpoints.
"""
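
# Usage sketch (illustrative, not part of the library): ``state`` and ``logger`` are assumed
# to come from an active training run (e.g. ``trainer.state`` and ``trainer.logger``); the
# filename shown is hypothetical.
#
#   checkpoint_paths = save_checkpoint(
#       state=state,
#       logger=logger,
#       filename="checkpoints/ep{epoch}-ba{batch}-rank{rank}.tar.gz",  # compressed tarball
#   )
#   # When not using DeepSpeed, only rank zero writes a file, so ``checkpoint_paths``
#   # contains a single path.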