load_checkpoint
- composer.utils.load_checkpoint(path, state, logger, object_store=None, load_weights_only=False, strict_model_weights=True, progress_bar=True, ignore_keys=None, exclude_algorithms=None, algorithm_passes=None)
Load a checkpoint from a local file, URI, or cloud object store into state.
- Parameters
path (str): The path format string to an existing checkpoint file. It can be a path to a file on the local disk, a URL, or, if object_store is set, the object name for a checkpoint in a cloud bucket.

When using DeepSpeed ZeRO, checkpoints are sharded by rank. Instead of hard-coding the rank in the path, use the following format variables:

Variable       Description
{rank}         The global rank, as returned by get_global_rank().
{local_rank}   The local rank of the process, as returned by get_local_rank().
{node_rank}    The node rank, as returned by get_node_rank().

For example, suppose that checkpoints are stored in the following structure:

my_model/ep1-rank0.tar
my_model/ep1-rank1.tar
my_model/ep1-rank2.tar
...

Then, path should be set to my_model/ep1-rank{rank}.tar, and all ranks will load the correct state.
state (State): The
State to load the checkpoint into.
logger (Logger): The Logger to log any information.
object_store (Union[ObjectStore, LoggerDestination], optional): If the
path is in an object store (e.g. AWS S3 or Google Cloud Storage), an instance of ObjectStore or LoggerDestination which will be used to retrieve the checkpoint. Otherwise, if the checkpoint is a local filepath, set to None. (default: None)
load_weights_only (bool, optional): Whether to restore only the model weights from the checkpoint, without restoring the associated state. (default: False)
strict_model_weights (bool, optional): Whether to require that the checkpointed weights exactly match the model weights. (default: True)
progress_bar (bool, optional): Whether to show a progress bar when downloading checkpoints. Ignored if the checkpoint is a local file path. (default: True)
ignore_keys (list[str] | (dict) -> None, optional):
A list of paths into the state_dict of the checkpoint which, when provided, will be dropped from the state_dict before the checkpoint is loaded. Each path is the sequence of keys used to index into state_dict, joined with / as a separator (because PyTorch uses . in parameter names). If a prefix is provided, all children are also ignored (see Example 2). See composer.core.state for the structure of state_dict.

Example 1: ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"] would ignore layer 1 weights and bias.
Example 2: ignore_keys = ["state/model/*"] would ignore the entire model, which would have the same effect as the previous example if there was only 1 layer.
Example 3: ignore_keys = ["state/model/layer*.weights"] would ignore all weights in the model.
Example 4: ignore_keys = ["state/rank_zero_seed", "rng"] would reset all randomness when loading the checkpoint.

If a callable, it should take one argument, which is the state_dict. The callable is free to arbitrarily modify the state_dict before it is loaded. (default: None)
exclude_algorithms (list[str], optional):
A list of algorithm names to exclude from loading. By default, algorithms with required_on_load=True that were enabled when training the loaded checkpoint are automatically applied unless they conflict with a user-specified algorithm. These algorithms often change the model, and not applying them could result in certain layers not having weights loaded.

Example 1: exclude_algorithms = ["BlurPool"] would exclude BlurPool from loading.
Example 2: exclude_algorithms = ["FusedLayerNorm", "Alibi"] would exclude FusedLayerNorm and Alibi from loading.

(default: None)
algorithm_passes (list[AlgorithmPass], optional): A list of algorithm passes to apply to autoloaded algorithms to sort them into the correct order. (default: None)
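The format variables described for path use Python's format-field syntax. As a minimal sketch of how a per-rank file name could be resolved, the snippet below fills the placeholders by hand. The helper resolve_checkpoint_path and the hard-coded rank values are hypothetical and not part of Composer; in a real run the values would come from get_global_rank(), get_local_rank(), and get_node_rank().

```python
def resolve_checkpoint_path(path_format: str, rank: int, local_rank: int, node_rank: int) -> str:
    """Fill the rank placeholders of a checkpoint path format string.

    Hypothetical helper for illustration only; not part of Composer.
    """
    # str.format ignores keyword arguments that the template does not use,
    # so a path containing only {rank} is handled the same way.
    return path_format.format(rank=rank, local_rank=local_rank, node_rank=node_rank)


# Simulate three processes on one node (assumed values): each one
# substitutes its own rank and therefore opens its own shard.
paths = [
    resolve_checkpoint_path("my_model/ep1-rank{rank}.tar", rank=r, local_rank=r, node_rank=0)
    for r in range(3)
]
print(paths)
```

Because the same format string is passed on every rank, no process needs rank-specific configuration.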
- Returns
Optional[list[dict[str, Any]]]: The RNG state dicts, indexed by global rank, if load_weights_only is False. Otherwise, None.
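The glob semantics described for ignore_keys can be illustrated on a plain nested dictionary. The sketch below is not Composer's implementation: it assumes that matching behaves like the standard-library fnmatch module applied to /-joined key paths, and the helpers flatten_paths, remove_path, and apply_ignore_keys are hypothetical names introduced here. The toy checkpoint dictionary is likewise made up.

```python
import fnmatch
from typing import Any


def flatten_paths(obj: Any, prefix: str = "") -> list[str]:
    """Return every '/'-joined key path in a nested dict, parents included."""
    paths = []
    if isinstance(obj, dict):
        for key, value in obj.items():
            child = f"{prefix}/{key}" if prefix else str(key)
            paths.append(child)
            paths.extend(flatten_paths(value, child))
    return paths


def remove_path(state_dict: dict, path: str) -> None:
    """Delete the entry at a '/'-joined path, if it still exists."""
    *parents, last = path.split("/")
    node = state_dict
    for key in parents:
        if not isinstance(node, dict) or key not in node:
            return  # an ancestor was already removed
        node = node[key]
    if isinstance(node, dict):
        node.pop(last, None)


def apply_ignore_keys(state_dict: dict, ignore_keys: list[str]) -> None:
    """Remove, in place, every path that matches one of the glob patterns."""
    for pattern in ignore_keys:
        matches = [p for p in flatten_paths(state_dict) if fnmatch.fnmatchcase(p, pattern)]
        for path in matches:
            remove_path(state_dict, path)


# Toy stand-in for a checkpoint state_dict (assumed structure).
checkpoint = {
    "state": {
        "model": {"layer1.weights": 1, "layer1.bias": 2, "layer2.weights": 3},
        "rank_zero_seed": 42,
    },
    "rng": [{"python": 0}],
}
apply_ignore_keys(checkpoint, ["state/model/layer*.weights", "rng"])
print(checkpoint)
```

Removing a parent path such as "rng" drops everything beneath it, which mirrors the prefix behaviour described for Example 2 and Example 4. A one-argument function such as lambda sd: apply_ignore_keys(sd, ["rng"]) has the shape described for the callable form of ignore_keys.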