load_checkpoint#
- composer.utils.load_checkpoint(path, state, logger, object_store=None, load_weights_only=False, strict_model_weights=False, progress_bar=True, ignore_keys=None, exclude_algorithms=None, algorithm_passes=None)[source]#
Load a checkpoint from a local file, URI, or cloud object store into state.
- Parameters
path (str) – The path format string to an existing checkpoint file. It can be a path to a file on the local disk, a URL, or, if object_store is set, the object name for a checkpoint in a cloud bucket.

When using DeepSpeed ZeRO, checkpoints are sharded by rank. Instead of hard-coding the rank in the path, use the following format variables:

Variable        Description
{rank}          The global rank, as returned by get_global_rank().
{local_rank}    The local rank of the process, as returned by get_local_rank().
{node_rank}     The node rank, as returned by get_node_rank().

For example, suppose that checkpoints are stored in the following structure:

my_model/ep1-rank0.tar
my_model/ep1-rank1.tar
my_model/ep1-rank2.tar
...

Then, path should be set to my_model/ep1-rank{rank}.tar, and all ranks will load the correct state (a sketch of this substitution appears after the Returns section).
state (State) – The State to load the checkpoint into.

logger (Logger) – The Logger to log any information.

object_store (Union[ObjectStore, LoggerDestination], optional) – If the path is in an object store (i.e. AWS S3 or Google Cloud Storage), an instance of ObjectStore or LoggerDestination which will be used to retrieve the checkpoint. Otherwise, if the checkpoint is a local filepath, set to None. (default: None)

load_weights_only (bool, optional) – Whether or not to only restore the model weights from the checkpoint, without restoring the associated state. (default: False)

strict_model_weights (bool, optional) – Whether or not to require that the checkpointed weights exactly match the model weights. (default: False)

progress_bar (bool, optional) – Whether or not to show a progress bar when downloading checkpoints. Ignored if the checkpoint is a local file path. (default: True)

ignore_keys (List[str] | (Dict) -> None, optional) – A list of paths into the state_dict of the checkpoint which, when provided, will be removed from the state_dict before the checkpoint is loaded. Each path is a list of strings specifying the keys to index into the state_dict, joined together with / as a separator (as PyTorch uses . in parameter names). If a prefix is provided, all children are also ignored (see Example 2). See composer.core.state for the structure of state_dict.

Example 1: ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"] would ignore layer 1's weights and bias.

Example 2: ignore_keys = ["state/model/*"] would ignore the entire model, which would have the same effect as the previous example if there were only 1 layer.

Example 3: ignore_keys = ["state/model/layer*.weights"] would ignore all weights in the model.

Example 4: ignore_keys = ["state/rank_zero_seed", "rng"] would reset all randomness when loading the checkpoint.

If a callable, it should take one argument, the state_dict; the callable is free to arbitrarily modify the state_dict before it is loaded (a sketch of this form appears after the Returns section). (default: None)

exclude_algorithms (List[str], optional) – A list of algorithm names to exclude from loading. By default, algorithms with required_on_load=True which were enabled when training the loaded checkpoint are automatically applied unless they conflict with a user-specified algorithm. These algorithms often change the model, and not applying them could result in certain layers not having weights loaded.
Example 1: exclude_algorithms = ["BlurPool"] would exclude BlurPool from loading.

Example 2: exclude_algorithms = ["FusedLayerNorm", "Alibi"] would exclude FusedLayerNorm and Alibi from loading.

(default: None)

algorithm_passes (List[AlgorithmPass], optional) – A list of algorithm passes to apply to autoloaded algorithms to sort them into the correct order. (default: None)
- Returns
Optional[List[Dict[str, Any]]] – The RNG state dicts, indexed by global rank, if load_weights_only is False. Otherwise, None.
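
A minimal usage sketch in Python. It assumes a composer.Trainer has already been constructed, so a populated State and Logger are available through trainer.state and trainer.logger (the trainer variable and the checkpoint path are illustrative, not part of this API):

from composer.utils import load_checkpoint

# Assumption: `trainer` is an already-constructed composer.Trainer, so
# `trainer.state` and `trainer.logger` are populated.
# `{rank}` is substituted per process, so each rank loads its own shard.
rng_state_dicts = load_checkpoint(
    path='my_model/ep1-rank{rank}.tar',  # illustrative sharded-checkpoint path
    state=trainer.state,
    logger=trainer.logger,
    load_weights_only=False,  # restore the full training state, not just the weights
    progress_bar=True,
)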
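
For reference, the format variables in the table above correspond to helpers in composer.utils.dist. The exact substitution is performed internally by load_checkpoint; the sketch below only mirrors the documented behavior with str.format:

from composer.utils import dist

path_template = 'my_model/ep1-rank{rank}.tar'

# On the process with global rank 2, this resolves to 'my_model/ep1-rank2.tar',
# matching the directory structure shown in the `path` description above.
# Keyword arguments for variables the template does not use are ignored.
resolved = path_template.format(
    rank=dist.get_global_rank(),
    local_rank=dist.get_local_rank(),
    node_rank=dist.get_node_rank(),
)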
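
Because ignore_keys may also be a callable that mutates the checkpoint's state_dict in place, here is a sketch of that form, continuing the assumptions above (the function name is hypothetical; the keys mirror Example 4):

def drop_rng_and_seed(state_dict):
    # Mutate the checkpoint's state_dict in place before it is loaded.
    # Equivalent in spirit to Example 4: drop all randomness state.
    state_dict.pop('rng', None)
    state_dict.get('state', {}).pop('rank_zero_seed', None)

rng_state_dicts = load_checkpoint(
    path='my_model/ep1-rank{rank}.tar',
    state=trainer.state,
    logger=trainer.logger,
    ignore_keys=drop_rng_and_seed,
)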