Source code for composer.utils.checkpoint

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Utilities for working with training checkpoints."""

from __future__ import annotations

import contextlib
import fnmatch
import logging
import os
import shutil
import tarfile
import tempfile
import textwrap
import warnings
from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from packaging import version
from torch.distributed import checkpoint as dist_cp
from torch.distributed._tensor import DeviceMesh
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner
from torch.distributed.distributed_c10d import ProcessGroup

from composer.utils import dist, reproducibility
from composer.utils.compression import get_compressor, is_compressed_pt
from composer.utils.file_helpers import (
from composer.utils.misc import ParallelismType, is_model_deepspeed, partial_format
from composer.utils.object_store import ObjectStore
from composer.utils.retrying import retry

    from composer.core import AlgorithmPass, State
    from composer.loggers import Logger, LoggerDestination

log = logging.getLogger(__name__)

__all__ = ['get_save_filename', 'load_checkpoint', 'save_checkpoint', 'download_checkpoint']

_DEEPSPEED_TAG = 'deepspeed'  # always tag with the same, deterministic name. We'll rename the tarball to the appropriate name.
_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME = f'__{dist.get_global_rank()}_0.distcp'

def _get_checkpoint_validation_function(
) -> Optional[Callable[[Union[Path, str], Optional[list[tuple[int, int]]]], bool]]:
    """Get the validation function specified by the environment variable `CHECKPOINT_VALIDATION_FUNCTION`.

        Callable[[Union[Path, str], Optional[int], Optional[int]], bool] The checkpoint validation function that returns
            True given a valid checkpoint and optionally a list of offsets and lengths to check and False otherwise.
    name = os.environ.get('CHECKPOINT_VALIDATION_FUNCTION', None)
    if name is None:
        return None
    splits = name.split('.')
    module_name, fn_name = '.'.join(splits[:-1]), splits[-1]
    module = import_module(module_name)
    fn = getattr(module, fn_name)
    log.debug(f'Checkpoint validation function {name} has been found.')
    return fn

def _ensure_valid_checkpoint(checkpoint_filepath: Union[Path, str],
                             specs: Optional[list[tuple[int, int]]] = None) -> Union[Path, str]:
    """Ensures that the checkpoint at checkpoint_filepath is valid.

    using the function specified by the CHECKPOINT_VALIDATION_FUNCTION environment variable.
    If CHECKPOINT_VALIDATION_FUNCTION is not set, we skip validation.

        checkpoint_filepath (Union[Path,str]): The path to the checkpoint file.
        specs (Optional[list[tuple[int,int]]]): A list of offsets and lengths to check. Defaults to None.

        ValueError if checkpoint file is invalid.
    # Get the validation function by name.
    validate = _get_checkpoint_validation_function()

    # No function name has been specified.
    if validate is None:
        return checkpoint_filepath

    # Validate the checkpoint.
    if not validate(checkpoint_filepath, specs):
        raise ValueError(f'Checkpoint at {checkpoint_filepath} is invalid.')

    log.debug(f'Checkpoint at {checkpoint_filepath} is valid.')
    return checkpoint_filepath

def _torch_load_with_validation(checkpoint_filepath: Union[Path, str], map_location: str) -> Any:
    """Validates and loads a torch checkpoint.

        checkpoint_filepath (Union[Path,str]): The path to the checkpoint file.
        map_location (str): The location to load the checkpoint to.
    return torch.load(_ensure_valid_checkpoint(checkpoint_filepath), map_location=map_location)

def _format_path_with_rank_zero(path: str) -> str:
    """Formats ``path`` with the rank zero values."""
    return path.format(

def _format_path_with_current_rank(path: str) -> str:
    """Formats ``path`` formatted with the current rank values."""
    return path.format(

def _get_write_mode(name: str) -> str:
    """Get the write mode to use with :func:``."""
    if name.endswith('.tar'):
        return 'w'
    if name.endswith('.tar.gz') or name.endswith('.tgz'):
        return 'w:gz'
    if name.endswith('.tar.bz2'):
        return 'w:bz2'
    if name.endswith('.tar.lzma'):
        return 'w:xz'
    raise ValueError(f'{name} does not end with a valid tarfile extension.')

def _get_num_ranks_that_saved_rng(metadata: Metadata):
    rng_inds = []
    for field_name, field_value in metadata.planner_data.items():
        if 'rng' in field_name:
            _, rng_rank_index, _ = field_value
    rng_inds = set(rng_inds)
    return len(rng_inds)

class FileSystemReaderWithValidation(dist_cp.FileSystemReader):
    """FileSystemReader that validates checkpoint files prior to reading."""

    def __init__(self, path: str):
        if _get_checkpoint_validation_function() is None:
  'No checkpoint validation function found when loading sharded checkpoints.')

    def read_data(self, plan: LoadPlan, planner: LoadPlanner):
        """Reads data file.

            ValueError if the data file is invalid.
        path_to_specs: dict[str, list[tuple[int, int]]] = {}
        for read_item in plan.items:
            item_md = self.storage_data[read_item.storage_index]
            path = os.path.join(self.path, item_md.relative_path)
            path_to_specs.setdefault(path, []).append((item_md.offset, item_md.length))
        for path, spec in path_to_specs.items():
            _ensure_valid_checkpoint(path, spec)
        return super().read_data(plan, planner)

    def read_metadata(self) -> Metadata:
        """Reads metadata file.

            ValueError if the metadata file is invalid.
        metadata_file_path = os.path.join(self.path, '.metadata')
        return super().read_metadata()

def download_object_or_file(
    object_name: str,
    file_destination: Union[str, Path],
    object_store: Union[ObjectStore, LoggerDestination],
    if isinstance(object_store, ObjectStore):

# A subclass of FileSystemReaderWithValidation that downloads files from the object store before reading them from the local filesystem.
class DistCPObjectStoreReader(FileSystemReaderWithValidation):

    def __init__(
        source_path: str,
        destination_path: str,
        object_store: Union[ObjectStore, LoggerDestination],
        device_mesh: Optional[DeviceMesh],
        self.source_path = source_path
        self.destination_path = destination_path
        self.object_store = object_store
        self.device_mesh = device_mesh

        # Download metadata file.
        Path(self.destination_path).mkdir(parents=True, exist_ok=True)
        metadata_destination = os.path.join(self.destination_path, '.metadata')
        if dist.get_local_rank() == 0:
            metadata_path = str(Path(source_path) / Path('.metadata'))
            download_object_or_file(metadata_path, metadata_destination, object_store)

        # FileSystemReader takes in a root directory in its constructor, which is the dir where
        # the metadata is expected to be stored. Also, this is parent directory for any shard file relative paths
        # specified in the metadata file.

    def read_data(self, plan: LoadPlan, planner: LoadPlanner):
        # Download files if not using HSDP or if on first replica with HSDP enabled
        first_replica = True
        if self.device_mesh is not None and self.device_mesh.mesh_dim_names is not None and ParallelismType.DATA_PARALLEL_REPLICATE.value in self.device_mesh.mesh_dim_names:
            hsdp_index = self.device_mesh.mesh_dim_names.index(ParallelismType.DATA_PARALLEL_REPLICATE.value)
            first_replica = self.device_mesh.get_local_rank(mesh_dim=hsdp_index) == 0

        # 1. Collect the relative paths to download for all ranks for deduplication
        relative_file_paths = set()
        for plan_item in plan.items:
        all_file_paths = dist.all_gather_object(relative_file_paths)

        # 2. Download to the destination all files this rank needs if on first replica
        download_error = False
        if first_replica:
            log.debug(f'Rank {dist.get_global_rank()} starting to download files.')

            # Get the lowest rank in the current node
            local_rank_0 = dist.get_global_rank() - dist.get_local_rank()

                for plan_item in plan.items:
                    relative_file_path = self.storage_data[plan_item.storage_index].relative_path
                    # Check if the file is scheduled to be downloaded by a lower rank on the same node
                    # i.e. if rank 0 and rank 1 on the same node have the same the same required file,
                    # only rank 0 should download it and not rank 1.
                    is_downloaded = any(
                        relative_file_path in all_file_paths[i] for i in range(local_rank_0, dist.get_global_rank())

                    # Download the shard file to the relative path it's associated to and save that relative path
                    # to the root directory specified to the FileSystem reader constructor.
                    file_destination = str(Path(self.destination_path) / Path(relative_file_path))

                    # The file could have already been downloaded as different plan items can point to same file.
                    if not is_downloaded and not os.path.exists(file_destination):
                        log.debug(f'Downloading {relative_file_path} to {file_destination}.')
                        object_name = str(Path(self.source_path) / Path(relative_file_path))
                        download_object_or_file(object_name, file_destination, self.object_store)
                        log.debug(f'Finished downloading {relative_file_path} to {file_destination}.')
            except Exception as e:
                log.error(f'Exception {type(e)} raised during downloading: {str(e)}')
                download_error = True

        # PyTorch will capture any exception of this function,
        # and dist.all_gather_objects(exception) before raising it.
        # If that all_gather_objects fails, the exception is never visible to user.
        # We raise the exception from all ranks to ensure the user sees it.
        download_error_tensor = dist.get_device(None).tensor_to_device(torch.tensor(1 if download_error else 0))
        error_by_rank = dist.all_gather(download_error_tensor)
        failed_ranks = []
        for rank, error in enumerate(list(error_by_rank)):
            if error > 0:
                download_error = True

        if download_error:
            raise RuntimeError(
                f'Ranks {failed_ranks} failed to download.',
                'To see the full error please look at the logs for that rank, which are logged via log.error.',

        # 3. Wait for all ranks to finish.
        log.debug(f'Rank {dist.get_global_rank()} finished downloading all files.')
        log.debug('Done waiting for all ranks to finish downloading files.')

        # 4. Broadcast files to all other replicas if HSDP
        if self.device_mesh is not None and self.device_mesh.mesh_dim_names is not None and ParallelismType.DATA_PARALLEL_REPLICATE.value in self.device_mesh.mesh_dim_names:
            # Broadcast file to all replicas
            replicate_index = self.device_mesh.mesh_dim_names.index(ParallelismType.DATA_PARALLEL_REPLICATE.value)
            shard_index = self.device_mesh.mesh_dim_names.index(ParallelismType.DATA_PARALLEL_SHARD.value)
            replicate_process_group = self.device_mesh.get_group(replicate_index)
            shard_process_group = self.device_mesh.get_group(shard_index)
            shard_size = self.device_mesh.size(shard_index)
            rank_in_first_replica = dist.get_global_rank() % shard_size
            sender = dist.get_global_rank() == rank_in_first_replica
            receiver = dist.get_global_rank() != rank_in_first_replica

            # Send list of files to all ranks
            file_list = [[
                file_name for file_name in sorted(os.listdir(self.destination_path)) if file_name.endswith('.distcp')
            dist.broadcast_object_list(file_list, src=rank_in_first_replica, group=replicate_process_group)
            file_list = file_list[0]
            log.debug(f'list of files to broadcast: {file_list}')

            # Send each file to the appropriate rank
            for file_name in file_list:
                if dist.get_local_rank() == 0 or (
                    dist.get_global_rank(shard_process_group) == 0  # pyright: ignore[reportGeneralTypeIssues]
                ):  # Only 1 rank per node needs to transfer file
                    full_path = os.path.join(self.destination_path, file_name)
                    log.debug(f'Transferring {full_path=}')
                    file_object = [None]
                    if sender:
                        with open(full_path, 'rb') as f:
                            file_object = [{'content':}]
                        src=dist.get_global_rank() % shard_size,
                    received_file_object = file_object[0]
                    assert received_file_object is not None
                    if receiver and not os.path.exists(full_path) and dist.get_local_rank() == 0:
                        with open(full_path, 'wb') as f:

            log.debug(f'Rank {dist.get_global_rank()} finished transferring files to all ranks.')
                f'Done waiting for all ranks to finish transferring files. Local checkpoint files: {sorted(os.listdir(self.destination_path))}',

        # 5. Piggyback off of the FileSystemReader to read all the files now that they are downloaded.
        return super().read_data(plan, planner)

[docs]class PartialFilePath: def __init__(self, filename: str, folder: Optional[str] = None): self.folder = folder self.filename = filename def format(self, state: State, is_deepspeed: bool = False, keep_placeholders: bool = False) -> str: # if filename already has a suffix (e.g., this would append to be extra_suffix = '.tar' if is_deepspeed and not is_tar(self.filename) else '' if self.folder: if keep_placeholders: return os.path.join( self.folder, self.filename, ) + extra_suffix else: return os.path.join( format_name_with_dist(self.folder, state.run_name), format_name_with_dist_and_time(self.filename, state.run_name, state.timestamp), ) + extra_suffix else: if keep_placeholders: return self.filename + extra_suffix else: return format_name_with_dist_and_time( self.filename, state.run_name, state.timestamp, ) + extra_suffix
def is_checkpoint_legacy_sharded(object_store: Optional[Union[LoggerDestination, ObjectStore]], source_path: str): if source_path.endswith('.symlink') or os.path.islink(source_path): source_path = extract_path_from_symlink(source_path, object_store=object_store) metadata_path = str(Path(source_path) / Path('.metadata')) if object_store is None: return not os.path.exists(metadata_path) else: try: _, _, metadata_path = parse_uri(metadata_path) with tempfile.TemporaryDirectory() as temp_dir: metadata_destination = os.path.join(str(temp_dir), '.metadata') download_object_or_file(metadata_path, metadata_destination, object_store) return False except FileNotFoundError: return True
[docs]def load_checkpoint( path: str, state: State, logger: Logger, object_store: Optional[Union[ObjectStore, LoggerDestination]] = None, load_weights_only: bool = False, strict_model_weights: bool = True, progress_bar: bool = True, ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None, exclude_algorithms: Optional[list[str]] = None, algorithm_passes: Optional[list[AlgorithmPass]] = None, ): """Load a checkpoint from a local file, URI, or cloud object store into ``state``. Args: path (str): The path format string to an existing checkpoint file. It can be a path to a file on the local disk, a URL, or if ``object_store`` is set, the object name for a checkpoint in a cloud bucket. When using `Deepspeed ZeRO <>`_, checkpoints are sharded by rank. Instead of hard-coding the rank in the ``path``, use the following format variables: +------------------------+-------------------------------------------------------+ | Variable | Description | +========================+=======================================================+ | ``{rank}`` | The global rank, as returned by | | | :func:`~.dist.get_global_rank`. | +------------------------+-------------------------------------------------------+ | ``{local_rank}`` | The local rank of the process, as returned by | | | :func:`~.dist.get_local_rank`. | +------------------------+-------------------------------------------------------+ | ``{node_rank}`` | The node rank, as returned by | | | :func:`~.dist.get_node_rank`. | +------------------------+-------------------------------------------------------+ For example, suppose that checkpoints are stored in the following structure: .. code-block:: my_model/ep1-rank0.tar my_model/ep1-rank1.tar my_model/ep1-rank2.tar ... Then, ``path`` should be set to ``my_model/ep1-rank{rank}.tar``, and all ranks will load the correct state. state (State): The :class:`~composer.core.State` to load the checkpoint into. logger (Logger): The :class:`~composer.logger.Logger` to log any information. object_store (Union[ObjectStore, LoggerDestination], optional): If the ``path`` is in an object store (i.e. AWS S3 or Google Cloud Storage), an instance of :class:`~.ObjectStore` or :class:`~.LoggerDestination` which will be used to retrieve the checkpoint. Otherwise, if the checkpoint is a local filepath, set to ``None``. (default: ``None``) load_weights_only (bool, optional): Whether or not to only restore the model weights from the checkpoint without restoring the associated state. (default: ``False``) strict_model_weights (bool, optional): Whether or not to force that the checkpointed weights must exactly match the model weights. (default: ``True``) progress_bar (bool, optional): Whether or not to show a progress bar when downloading checkpoints. Ignored if the checkpoint is a local file path. (default: ``True``) ignore_keys (list[str] | (dict) -> None, optional): A list of paths for the ``state_dict`` of the checkpoint, which, when provided, will be ignored from the state_dict before a checkpoint is loaded. Each path is a list of strings specifying the keys to index into ``state_dict`` joined together with `/` as a separator (as PyTorch uses `.` in parameter names). If a prefix is provided, all children are also ignored (see Example 2). See :mod:`composer.core.state` for the structure of state_dict. Example 1: ``ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"]`` would ignore layer 1 weights and bias. Example 2: ``ignore_keys = ["state/model/*"]`` would ignore the entire model, which would have the same effect as the previous example if there was only 1 layer. Example 3: ``ignore_keys = ["state/model/layer*.weights"]`` would ignore all weights in the model. Example 4: ``ignore_keys = ["state/rank_zero_seed", "rng"]`` would reset all randomness when loading the checkpoint. If a callable, it should take one argument which is the state_dict. The callable is free to arbitrarily modify the state_dict before it is loaded. (default: ``None``) exclude_algorithms (list[str], optional): A list of algorithm names to exclude from loading. By default, algorithms with `required_on_load=True` which were enabled when training the loaded checkpoint are automatically applied unless they conflict with a user specified algorithm. These algorithms often change the model, and not applying them could result in certain layers not having weights loaded. Example 1: ``exclude_algorithms = ["BlurPool"]`` would exclude BlurPool from loading. Example 2: ``exclude_algorithms = ["FusedLayerNorm", "Alibi"]`` would exclude FusedLayerNorm and Alibi from loading. (default: ``None``) algorithm_passes (list[AlgorithmPass], optional): A list of algorithm passes to apply to autoloaded algorithms to sort them into the correct order. (default: ``None``) Returns: Optional[list[dict[str, Any]]]: The RNG state dicts, indexed by global rank, if :attr:`load_weights_only` is not None. Otherwise, None. """ path = partial_format(path, run_name=state.run_name) using_legacy_sharded = False if state.fsdp_sharded_state_dict_enabled: assert object_store is None or isinstance( object_store, ObjectStore, ), 'For loading sharded checkpoints load_object_store must be set with the class ObjectStore' using_legacy_sharded = is_checkpoint_legacy_sharded(object_store, path) if state.fsdp_sharded_state_dict_enabled and not using_legacy_sharded: rng_state_dicts = load_sharded_checkpoint( source_path=path, state=state, logger=logger, object_store=object_store, load_weights_only=load_weights_only, strict_model_weights=strict_model_weights, progress_bar=progress_bar, ignore_keys=ignore_keys, exclude_algorithms=exclude_algorithms, algorithm_passes=algorithm_passes, ) else: # Download the checkpoint to the node-local folder log.debug('Loading checkpoint at %s', path) # Each node gets one unique folder to store checkpoints that is shared amongst all local ranks in that node. # If fsdp sharded state_dicts is enabled then EVERY rank gets a unique checkpoint folder. needs_unique_checkpoint_folder = state.fsdp_sharded_state_dict_enabled or dist.get_local_rank() == 0 tempdir_ctx = tempfile.TemporaryDirectory() if needs_unique_checkpoint_folder else contextlib.nullcontext(None) with tempdir_ctx as tempdir: try: # Get the path to the proper checkpoint folder corresponding to the current rank's node. # If fsdp_sharded_state_dict_enabled then just use that rank's unique tempdir. node_checkpoint_folder = ( tempdir if state.fsdp_sharded_state_dict_enabled else _get_local_rank_zero_path(tempdir) ) assert node_checkpoint_folder is not None composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n = download_checkpoint( path=path, node_checkpoint_folder=node_checkpoint_folder, object_store=object_store, progress_bar=progress_bar, fsdp_sharded_state_dict_enabled=state.fsdp_sharded_state_dict_enabled, deepspeed_sharded_checkpoint=is_model_deepspeed(state.model), ) rng_state_dicts = _restore_checkpoint( state, logger, composer_states_filepath, extracted_rank_n, extracted_checkpoint_folder, load_weights_only=load_weights_only, strict_model_weights=strict_model_weights, ignore_keys=ignore_keys, exclude_algorithms=exclude_algorithms, algorithm_passes=algorithm_passes, ) finally: # Wait for all ranks to finish restoring the checkpoint before releasing the tempdir, since tempdir can # be a shared resource between nodes. dist.barrier()'%s loaded from %s', 'Model weights' if load_weights_only else 'Trainer checkpoint', path) step_to_resume_from = state.timestamp.batch.value max_step_to_resume_from = state.device.tensor_to_device( torch.tensor(state.timestamp.batch.value, dtype=torch.int64), ) min_step_to_resume_from = state.device.tensor_to_device( torch.tensor(state.timestamp.batch.value, dtype=torch.int64), ) dist.all_reduce(max_step_to_resume_from, reduce_operation='MAX') dist.all_reduce(min_step_to_resume_from, reduce_operation='MIN') if != raise RuntimeError( textwrap.dedent( f'Timestamp mismatch error: batch to resume from {step_to_resume_from} is not the same on all ranks. ' 'This usually occurs when at least one rank fails to save the last checkpoint ' 'while using sharded checkpointing + autoresume. ' 'Please manually resume by disabling autoresume and explicitly setting load_path ' 'to the most recent checkpoints that all ranks have saved. ' 'E.g. for the 10th batch: trainer = Trainer(autoresume=False, load_path="/path/to/checkpoint/ba10-rank{rank}.pt", ...). ' 'Remember to keep the {rank} placeholder!', ), ) return rng_state_dicts
def load_sharded_checkpoint( source_path: str, state: State, logger: Logger, object_store: Optional[Union[ObjectStore, LoggerDestination]] = None, load_weights_only: bool = False, strict_model_weights: bool = False, progress_bar: bool = True, ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None, exclude_algorithms: Optional[list[str]] = None, algorithm_passes: Optional[list[AlgorithmPass]] = None, ) -> Union[list[dict], None]: using_multinode = dist.get_world_size() != dist.get_local_world_size() if not version.parse(torch.__version__) >= version.parse('2.0.1') and using_multinode: raise ValueError( f'Sharded checkpoint loading on >1 node requires torch version >= 2.0.1. You have torch version {torch.__version__}', ) if state.fsdp_config is None: raise ValueError('Loading a sharded checkpoint requires passing an FSDP config to Trainer.') # Check to make sure source_path is a directory. if object_store is None: if os.path.islink(source_path): source_path = os.path.join(os.path.dirname(source_path), os.readlink(source_path)) if os.path.exists(source_path): if not os.path.isdir(source_path): raise ValueError(f'load_path must be a directory when using sharded state dict. Got {source_path}') else: raise FileNotFoundError(f'{source_path} not found!') download_dir_context = tempfile.TemporaryDirectory if object_store is not None else contextlib.nullcontext with download_dir_context() as temp_download_dir: if object_store is not None: # Get the tempfile made on local rank 0. local_rank0_index = dist.get_global_rank() - dist.get_local_rank() rank0_download_tempdir = str(dist.all_gather_object(temp_download_dir)[local_rank0_index]) if source_path.endswith('.symlink'): source_path = extract_path_from_symlink(source_path, object_store=object_store) storage_reader = DistCPObjectStoreReader( source_path=source_path, destination_path=str(Path(rank0_download_tempdir) / Path('checkpoints')), object_store=object_store, device_mesh=state.device_mesh, ) else: storage_reader = FileSystemReaderWithValidation(source_path) # We need no_grad because we overwrite tensor values with set_() when we do elastic loading and we don't want the set_ op recorded in the computation graph. with torch.no_grad(): # 1. Load metadata first for backwards compatability check # We need to check if the "optimizers" is at the root of the state dict to determine # how to load the optimizer state. try: metadata = storage_reader.read_metadata() except AttributeError as e: if '_MEM_FORMAT_ENCODING' in str(e): raise ValueError( 'Unable to read checkpoint metadata. The checkpoint was likely saved with a ' 'newer version of torch. Upgrade your torch version to load this checkpoint.', ) else: raise # Retrieve all top-level keys of the metadata. top_level_keys = [v[0] for v in metadata.planner_data.values()] optimizers_at_root = 'optimizers' in top_level_keys # 2. Load model and metadata if load_weights_only: state_dict: dict[str, Any] = {'state': {'model': state.get_model_state_dict()}} else: cur_state_dict = state.state_dict() # If 'optimizers' is at root-level, we load it separately. if optimizers_at_root: cur_state_dict.pop('optimizers') num_rng_ranks = _get_num_ranks_that_saved_rng(storage_reader.read_metadata()) state_dict: dict[str, Any] = { 'state': cur_state_dict, 'rng': reproducibility.get_rng_state()[:num_rng_ranks], } if ignore_keys: # Filter provided list of key paths if not callable(ignore_keys): ignore_keys = glob_filter(ignore_keys) # Call function to modify state_dict ignore_keys(state_dict) # Ensure state exists state_dict['state'] = state_dict.get('state', {}) # dist_cp.load breaks unless the specified state_dict supports `load_state_dict` # See: dist_cp.load_state_dict( state_dict=state_dict, storage_reader=storage_reader, planner=state.fsdp_config.load_planner, no_dist=(not dist.is_initialized()), )'Loaded state dict') state.load_state_dict( state_dict['state'], logger, strict=strict_model_weights, exclude_algorithms=exclude_algorithms, algorithm_passes=algorithm_passes, ) # 3. Optionally load optimizer # If 'optimizers' was not at root-level, then it will already be loaded if optimizers_at_root and not load_weights_only: optim_state = load_sharded_optimizer_state_dict( model_state_dict=state.state_dict()['model'], optimizer_key='optimizers', storage_reader=storage_reader, ) state._legacy_load_optim_state(optim_state) return state_dict.get('rng', None) def _get_local_rank_zero_path(path: Optional[str]) -> str: """Broadcasts the ``path`` from the LOCAL rank zero to all LOCAL ranks.""" local_rank_zero = dist.get_global_rank() - dist.get_local_rank() paths = dist.all_gather_object(path) local_rank_zero_path = paths[local_rank_zero] assert local_rank_zero_path is not None, 'local rank zero provides the path' return local_rank_zero_path def download_checkpoint( path: str, node_checkpoint_folder: str, object_store: Optional[Union[ObjectStore, LoggerDestination]], progress_bar: bool, fsdp_sharded_state_dict_enabled: bool = False, deepspeed_sharded_checkpoint: bool = False, ) -> tuple[str, Optional[str], bool]: """Download the checkpoint stored at ``path``, potentially in ``object_store``, to ``node_checkpoint_folder``. Returns a tuple of (``composer_states_filepath``, ``extracted_checkpoint_folder``, ``extracted_rank_n``). * The ``composer_states_filepath``, is the path to the composer states, which can be passed into :meth:`torch.load`. * The ``extracted_checkpoint_folder`` is the path to the checkpoint folder, which can be passed into :meth:`deepspeed.DeepSpeedEngine.load_checkpoint`. * The ``extracted_rank_n`` is a boolean flag indicating whether a tarball was extracted on global rank greater than 0. """ log.debug('Downloading checkpoint to folder %s', node_checkpoint_folder) rank_zero_checkpoint_filepath = os.path.join(node_checkpoint_folder, 'rank0_checkpoint') rank_n_checkpoint_filepath = os.path.join(node_checkpoint_folder, f'rank{dist.get_global_rank()}_checkpoint') extracted_checkpoint_folder = None extracted_rank_n = False if is_tar(path): extracted_checkpoint_folder = os.path.join(node_checkpoint_folder, 'checkpoint') composer_states_filepath = os.path.join(extracted_checkpoint_folder, _COMPOSER_STATES_FILENAME) else: # it's not an archive; it's just the composer state dict # and only rank zero has this file unless fsdp_sharded_state_dict_enabled then # every rank has it's own file. extracted_checkpoint_folder = None composer_states_filepath = ( rank_n_checkpoint_filepath if fsdp_sharded_state_dict_enabled else rank_zero_checkpoint_filepath ) if is_compressed_pt(path): original_path = path path = os.path.splitext(path)[0] compressor = get_compressor(original_path) with open(path, 'wb') as out_file: with compressor.decompress(original_path) as in_file: shutil.copyfileobj(in_file, out_file) checkpoint_is_sharded = fsdp_sharded_state_dict_enabled or deepspeed_sharded_checkpoint try: if not checkpoint_is_sharded and dist.get_local_rank() == 0: # If the checkpoint is not sharded, then local rank 0 on each node needs to download the # global rank 0 checkpoint path = _format_path_with_rank_zero(path) get_file( destination=rank_zero_checkpoint_filepath, path=path, object_store=object_store, progress_bar=progress_bar, ) if extracted_checkpoint_folder is not None: try: with as tarball: tarball.extractall(extracted_checkpoint_folder) except FileNotFoundError: # Not re-raising the file-not-found error as that is irrelevant; # the underlying issue is that the checkpoint file does not exist on the disk # or could not be downloaded raise RuntimeError(f'Checkpoint {path} does not exist') elif checkpoint_is_sharded: # If the checkpoint is sharded, then every rank needs to download its own checkpoint path = _format_path_with_current_rank(path) try: get_file( destination=rank_n_checkpoint_filepath, path=path, object_store=object_store, progress_bar=progress_bar, ) except FileNotFoundError as e: raise FileNotFoundError(( f'Checkpoint {path} does not exist, but is required for sharded checkpointing ' f'on rank {dist.get_global_rank()}. Please ensure that the checkpoint exists ' 'and your load_path was specified as a format string with the {rank} argument.' )) from e if extracted_checkpoint_folder is not None: try: # it's an archive and needs to be extracted with as tarball: tarball.extractall(extracted_checkpoint_folder) extracted_rank_n = True except FileNotFoundError: # this will happen most of the time (i.e. whenever deepspeed # is not being used) so not logging anything pass finally: # Use busy wait to avoid timeouts on large downloads for non-sharded checkpoints if not checkpoint_is_sharded: signal_file_path = os.path.join( node_checkpoint_folder, f'.node_{dist.get_node_rank()}_local_rank0_completed', ) if dist.get_local_rank() == 0: with open(signal_file_path, 'wb') as f: f.write(b'local_rank0_completed') # Avoid the collective call until the local rank zero has finished trying to download the # checkpoint so that we don't timeout for large downloads. This syncs all processes on the # node with dist.local_rank_zero_download_and_wait(signal_file_path): # Then, wait to ensure every node has finished downloading the checkpoint dist.barrier() if dist.get_local_rank() == 0: os.remove(signal_file_path) dist.barrier() return composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n def _flatten_keys(obj: Any, paths: list[str], existing_path: str): """Recursively flatten the keys of a dictionary or list into a set of paths.""" # Store path when we reach end, which is either non-dict or empty dict if isinstance(obj, list) and len(obj) > 0: for i, elm in enumerate(obj): _flatten_keys(elm, paths, f'{existing_path}/{i}') elif isinstance(obj, dict) and len(obj) > 0: for k, v in obj.items(): _flatten_keys(v, paths, f'{existing_path}/{k}') # Remove leading / paths.append(existing_path.lstrip('/')) def _remove_paths(obj: Union[list, dict[str, Any]], exclude_paths: list[list[str]]): # Build str(key) to key map to undo cast from glob filtering. Despite typing, some state_dict # keys are not strings, so we need to cast them back to their original type. str_key_to_key = {} if isinstance(obj, dict): for key in obj.keys(): str_key_to_key[str(key)] = key # First determine the keys which will be recursed on and which will be removed entirely # Group the `exclude_paths` by the key keys_to_recurse = {} keys_to_remove = [] for exclude_path_parts in exclude_paths: key = exclude_path_parts[0] # Cast list indices to int if isinstance(obj, list): key = int(key) # Un-str dict keys if necessary if key in str_key_to_key: key = str_key_to_key[key] if len(exclude_path_parts) == 1: keys_to_remove.append(key) else: if key not in keys_to_recurse: keys_to_recurse[key] = [] keys_to_recurse[key].append(exclude_path_parts[1:]) # Recurse first, so in the case of a list, the indexing is consistent for key, paths_to_recurse in keys_to_recurse.items(): _remove_paths(obj[key], paths_to_recurse) # Sort the keys in reverse order, so in the case of a list, the indexing is consistent keys_to_remove.sort(reverse=True) # Remove the keys for key in keys_to_remove: del obj[key] def glob_filter(exclude_globs: list[str]) -> Callable[[dict], None]: """Provides a function which deletes all subparts of a dictionary based on a list of paths.""" def filter_func(state_dict: dict) -> None: # Flatten dictionary into paths paths = [] _flatten_keys(state_dict, paths, '/') filtered_paths = [] for exclude_glob in exclude_globs: filtered_paths_from_glob = fnmatch.filter(paths, exclude_glob) if len(filtered_paths_from_glob) == 0: warnings.warn( f'No parts from loaded checkpoint state_dict were ignored by load_ignore_key {exclude_glob}', ) filtered_paths.extend(filtered_paths_from_glob) filtered_paths = list(set(filtered_paths)) if filtered_paths: filtered_paths_str = ', '.join(filtered_paths)'Ignoring the following paths from the loaded checkpoint state_dict: {filtered_paths_str}') # Loop through all paths to exclude paths_to_remove = [path.split('/') for path in filtered_paths if len(path) > 0] _remove_paths(state_dict, paths_to_remove) return filter_func
[docs]def safe_torch_load( composer_states_filepath: Union[Path, str], map_location: str = 'cpu', load_monolith_rank0_only: bool = False, ) -> dict[str, Any]: """Load a torch checkpoint, catching errors due to backwards compatibility issues. Args: composer_states_filepath: The path to the checkpoint file. map_location: The location to load the checkpoint to. load_monolith_rank0_only: Whether the checkpoint is a monolith FSDP checkpoint. """ try: if load_monolith_rank0_only: 'Loading monolith FSDP checkpoint. Only rank 0 will load and broadcast non-weight/optimizer state.', ) state_dict_list = [None] model = None optimizer = None if dist.get_global_rank() == 0: state_dict_list[0] = _torch_load_with_validation(composer_states_filepath, map_location=map_location) # Don't broadcast model/optimizer state if they exist if 'model' in state_dict_list[0]['state']: model = state_dict_list[0]['state']['model'] state_dict_list[0]['state']['model'] = None if 'optimizers' in state_dict_list[0]['state']: optimizer = state_dict_list[0]['state']['optimizers'] state_dict_list[0]['state']['optimizers'] = None log.debug('Broadcasting state_dict to all ranks.') dist.broadcast_object_list(state_dict_list, src=0) state_dict: dict[str, Any] = state_dict_list[0] # type: ignore if dist.get_global_rank() == 0: if model is not None: state_dict['state']['model'] = model if optimizer is not None: state_dict['state']['optimizers'] = optimizer return state_dict else: return _torch_load_with_validation(composer_states_filepath, map_location=map_location) except TypeError as e: if 'Accuracy.__new__() missing 1 required positional argument' in str(e): raise Exception( 'As of v0.10.0, torchmetrics introduces a new required argument to Accuracy which ' 'breaks backwards compatibility. Unfortunately, this means that older checkpoints ' 'cannot be loaded with the metrics. In order to successfully load this model, please ' 'pass `load_ignore_keys = ["state/train_metrics/*", "state/eval_metrics/*"]`.', ) from e raise e
def _restore_checkpoint( state: State, logger: Logger, composer_states_filepath: str, extracted_rank_n: bool, extracted_checkpoint_folder: Optional[str], load_weights_only: bool, strict_model_weights: bool, ignore_keys: Optional[Union[list[str], Callable[[dict], None]]], exclude_algorithms: Optional[list[str]], algorithm_passes: Optional[list[AlgorithmPass]], ) -> Optional[list[dict[str, Any]]]: """Restore a checkpoint into ``state`` and returns the rng state dicts (if ``load_weights_only`` is False).""" # Now, all ranks load the checkpoint that local rank zero downloaded state_dict = safe_torch_load( composer_states_filepath=composer_states_filepath, load_monolith_rank0_only=state.load_monolith_rank0_only, ) if ignore_keys: # Filter provided list of key paths if not callable(ignore_keys): ignore_keys = glob_filter(ignore_keys) # Call function to modify state_dict ignore_keys(state_dict) # Ensure state exists state_dict['state'] = state_dict.get('state', {}) log.debug(f"Loaded checkpoint with keys {state_dict.keys()} and state keys {state_dict['state'].keys()}") if is_model_deepspeed(state.model): if extracted_checkpoint_folder is None: raise RuntimeError('Deepspeed checkpoints require a tarball, not a weights file.') global_rank = dist.get_global_rank() if global_rank > 0 and not extracted_rank_n: raise RuntimeError(f'Deepspeed checkpoint missing for rank {global_rank}') load_path, _ = state.deepspeed_model.load_checkpoint( extracted_checkpoint_folder, tag=_DEEPSPEED_TAG, load_module_only=load_weights_only, load_module_strict=strict_model_weights, ) if load_path is None: raise RuntimeError(f'Failed to load DeepSpeed checkpoint') elif load_weights_only: state.load_model_state( state_dict['state'], logger, strict=strict_model_weights, exclude_algorithms=exclude_algorithms, algorithm_passes=algorithm_passes, ) if not load_weights_only: state.load_state_dict( state_dict['state'], logger, exclude_algorithms=exclude_algorithms, algorithm_passes=algorithm_passes, ) return state_dict.get('rng', None)
[docs]def get_save_filename( state: State, filename: str = 'ep{epoch}-ba{batch}-rank{rank}', ) -> str: """Gets full filename of save filename. Args: state (State): The :class:`~composer.core.State` to load the checkpoint into. filename (filename): The name of the save file. Returns: Full filename of save file. """ if not state.fsdp_sharded_state_dict_enabled: is_deepspeed = is_model_deepspeed(state.model) return PartialFilePath(filename).format(state, is_deepspeed) # Sharded checkpoints get their own little folder. assert state.fsdp_config is not None remote_prefix = state.fsdp_config.sharded_ckpt_prefix_dir assert remote_prefix is not None save_dirpath = Path(Path(filename).parent) / Path(remote_prefix) save_dirpath = format_name_with_dist_and_time(str(save_dirpath), state.run_name, state.timestamp) # New name is now Trainer.save_folder / sharded_ckpt_prefix_dir / __{dist.get_global_rank()}_0.distcpโ€™ # e.g. path/to/my/checkpoints/ep1-ba2/__1_0.distcp ckpt_filename = _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME return str(Path(save_dirpath) / Path(ckpt_filename))
def _save_checkpoint( state: State, save_filename: str, *, weights_only: bool = False, ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None, ) -> Union[str, None]: # noqa: D103 is_deepspeed = is_model_deepspeed(state.model) if weights_only and not is_deepspeed: state_dict = { 'state': { 'model': state.get_model_state_dict(), 'integrations': state._get_integrations_state_dict(), 'metadata': state._get_state_metadata(), }, } else: state_dict = { 'state': state.state_dict(), 'rng': reproducibility.get_rng_state(), } if ignore_keys: # Filter provided list of key paths if not callable(ignore_keys): ignore_keys = glob_filter(ignore_keys) # Call function to modify state_dict ignore_keys(state_dict) # Ensure state exists state_dict['state'] = state_dict.get('state', {}) if state.fsdp_sharded_state_dict_enabled and not weights_only: # Only rank 0 saves RNG if dist.get_global_rank() > 0: state_dict.pop('rng') # To load optimizer states with 2.0 <= torch < 2.2.3 , the optimizer state must be at the top # level of the state dict because the load_sharded_optimizer_state_dict function # requires a top level state dict key for the optimizer. # See # for more info. if version.parse(torch.__version__) < version.parse('2.2.3'): state_dict['optimizers'] = state_dict['state'].pop('optimizers') log.debug('State dict created.') dirname = os.path.dirname(save_filename) if dirname: os.makedirs(dirname, exist_ok=True) # Only some ranks are meant to save checkpoint and produce a file expect_file = False # Save deepspeed checkpoint if is_deepspeed: expect_file = True log.debug('Saving deepspeed checkpoints to %s...', save_filename) if dist.get_global_rank() == 0: _write_checkpoint_file(state_dict, save_filename) _save_deepspeed_model(state.deepspeed_model, save_filename) # Save sharded checkpoint elif state.fsdp_sharded_state_dict_enabled: if state.fsdp_config is None: raise ValueError('Saving a sharded checkpoint requires passing an FSDP config to Trainer.') log.debug(f'Saving sharded checkpoints to {save_filename}...') process_group = None device_mesh = state.device_mesh if device_mesh is not None and device_mesh.mesh_dim_names is not None and ParallelismType.DATA_PARALLEL_REPLICATE.value in device_mesh.mesh_dim_names: # If hybrid shard, only rank in first replica saves hsdp_index = device_mesh.mesh_dim_names.index(ParallelismType.DATA_PARALLEL_REPLICATE.value) expect_file = device_mesh.get_local_rank(mesh_dim=hsdp_index) == 0 if expect_file: process_group = device_mesh.get_group(1) # Shard process_group for first replica assert isinstance(process_group, ProcessGroup) # For type checker log.debug(f'Saving on global_rank={dist.get_global_rank()}, {expect_file=}') else: expect_file = True if expect_file: if version.parse(torch.__version__) >= version.parse('2.3.0'): save_planner = state.fsdp_config.save_planner if save_planner is None: from composer.trainer._patch_pytorch import SavePlannerWithDedupFix save_planner = SavePlannerWithDedupFix() state_dict=state_dict, storage_writer=dist_cp.FileSystemWriter(dirname), planner=save_planner, process_group=process_group, ) else: dist_cp.save_state_dict( state_dict=state_dict, storage_writer=dist_cp.FileSystemWriter(dirname), planner=state.fsdp_config.save_planner, process_group=process_group, ) log.debug('Finished pytorch save state dict') # Save monolith checkpoint elif dist.get_global_rank() == 0: expect_file = True log.debug(f'Saving monolithic checkpoint to {save_filename}') _write_checkpoint_file(state_dict, save_filename) log.debug(f'Global rank 0 done saving checkpoint to disk at {save_filename}.') else: log.debug(f'Only rank 0 is saving a checkpoint, so rank {dist.get_global_rank()} skips checkpointing.') dist.barrier() # Ensure all ranks saved their files if expect_file: assert os.path.exists(save_filename), 'Expected file to have been saved.' return save_filename else: # no file saved return None def _write_checkpoint_file(state_dict: dict[str, Any], filename: str) -> None: """Write the given checkpoint state to the given path. Compressing if indicated to do so by the file extension.""" if is_tar(filename): log.debug('Writing checkpoint tar file %s', filename) write_mode = _get_write_mode(filename) with tempfile.TemporaryDirectory(prefix='checkpoint') as tmpdir: with open(os.path.join(tmpdir, _COMPOSER_STATES_FILENAME), 'wb') as f:, f) with, write_mode) as tarball: tarball.add(tmpdir, arcname='') elif is_compressed_pt(filename): log.debug('Writing compressed checkpoint %s', filename) compressor = get_compressor(filename) with compressor.compress(filename) as f:, f) else: log.debug('Writing uncompressed checkpoint %s', filename) with open(filename, 'wb') as f:, f) def _save_deepspeed_model(model, filename: str): """Save Deepspeed model and tarball the files.""" write_mode = _get_write_mode(filename) read_mode = 'r' + write_mode[1:] with tempfile.TemporaryDirectory() as tmpdir: model.save_checkpoint(tmpdir, _DEEPSPEED_TAG) if os.path.exists(filename): # extract to tmpdir to append below # not all compression formats support direct append with, read_mode) as tar: tar.extractall(tmpdir) with, write_mode) as tar: tar.add(tmpdir, arcname='')
[docs]def save_checkpoint( state: State, filename: str = 'ep{epoch}-ba{batch}-rank{rank}', *, weights_only: bool = False, ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None, ) -> Union[str, None]: # noqa: D103 # Clear the cache in case we are near the memory limit to give some space for NCCL. torch.cuda.empty_cache() save_filename = get_save_filename(state, filename) return _save_checkpoint(state, save_filename, weights_only=weights_only, ignore_keys=ignore_keys)
save_checkpoint.__doc__ = f"""Checkpoint the training ``state``. Args: state (State): The training state. logger (Logger): The logger. filename (str): A format string describing how to name checkpoints. (default: ``'ep{{epoch}}-ba{{batch}}-rank{{rank}}'``) The following format variables are available: {textwrap.indent(FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, prefix=' ')} .. note:: * By default, only the rank zero process will save a checkpoint file. * When using DeepSpeed, each rank will save a checkpoint file in tarball format. DeepSpeed requires tarball format, as it saves model and optimizer states in separate files. Ensure that ``'{{rank}}'`` appears within the ``filename``. Otherwise, multiple ranks may attempt to write to the same file(s), leading to corrupted checkpoints. If no tarball file extension is specified, ``.tar`` will be used. * To write to compressed tar files (regardless of whether DeepSpeed is enabled), set the file extension to ``'.tar.gz'``, ``'.tgz'``, ``'.tar.bz2'``, or ``'.tar.lzma'`` (depending on the desired compression algorithm). * To write to compressed pt files (when DeepSpeed is disabled), set the file extension to ``'.pt.bz2'``, ``'.pt.gz'``, ``'.pt.lz4'``, ``'.pt.lzma'``, ``'.pt.lzo'``, ``'.pt.xz'``, ``'.pt.zst'`` (depending on the desired algorithm). You must have the corresponding CLI tool installed. ``lz4`` is a good choice for a modest space saving while being very fast to compress. .. warning:: Using compression will block the training loop while checkpoints are being compressed and the compressibility of checkpoints can vary significantly depending on your setup. As such, we recommend saving checkpoints without compression by default. If you have the ``lz4`` command available on your system, you may want to try saving as ``.pt.lz4`` as the overhead is minimal (usually less than a second) and the saved space can sometimes be significant (1% - 40%). Consider the following scenario, where: * The default ``name='ep{{epoch}}-ba{{batch}}-rank{{rank}}'`` is used. * The current epoch count is ``1``. * The current batch count is ``42``. When DeepSpeed is not being used, the rank zero process will save the checkpoint to ``'ep1-ba42-rank0'``. When DeepSpeed is being used, each rank (process) will save checkpoints to:: ep1-ba42-rank0.tar ep1-ba42-rank1.tar ep1-ba42-rank2.tar ... weights_only (bool, optional): If ``True``, save only the model weights instead of the entire training state. (default: ``False``) .. note:: When using DeepSpeed, this parameter must be ``False``. Weights-only checkpointing is not currently compatible with DeepSpeed, Returns: list[pathlib.Path]: The list of checkpoint files saved, indexed by the rank of the process. .. note:: When using DeepSpeed, each process (rank) saves its own checkpoint file. When doing multi-node training, the filepaths are valid only on each process's node; Composer does not move checkpoint files between nodes. Otherwise, when not using DeepSpeed, each list will contain only one filepath, since only the rank zero process saves checkpoints. """