load_checkpoint
- composer.utils.load_checkpoint(path, state, logger, object_store=None, load_weights_only=False, strict_model_weights=True, progress_bar=True, ignore_keys=None, exclude_algorithms=None, algorithm_passes=None)
Load a checkpoint from a local file, URI, or cloud object store into state.
- Parameters
path (str) –
The path format string to an existing checkpoint file.
It can be a path to a file on the local disk, a URL, or, if object_store is set, the object name for a checkpoint in a cloud bucket.
When using FSDP with sharded checkpointing, checkpoint files are sharded by rank, and load_path should be set to the directory containing sharded checkpoint files.
state (State) – The State to load the checkpoint into.
logger (Logger) – The Logger to log any information.
object_store (Union[ObjectStore, LoggerDestination], optional) – If the path is in an object store (e.g. AWS S3 or Google Cloud Storage), an instance of ObjectStore or LoggerDestination which will be used to retrieve the checkpoint. Otherwise, if the checkpoint is a local filepath, set to None. (default: None)
load_weights_only (bool, optional) – Whether to restore only the model weights from the checkpoint, without restoring the associated state. (default: False)
strict_model_weights (bool, optional) – Whether to require that the checkpointed weights exactly match the model weights. (default: True)
progress_bar (bool, optional) – Whether to show a progress bar when downloading checkpoints. Ignored if the checkpoint is a local file path. (default: True)
ignore_keys (list[str] | (dict) -> None, optional) –
A list of paths for the state_dict of the checkpoint, which, when provided, will be removed from the state_dict before a checkpoint is loaded. Each path is the sequence of keys to index into state_dict, joined together with / as a separator (as PyTorch uses . in parameter names). If a prefix is provided, all children are also ignored (see Example 2). See composer.core.state for the structure of state_dict.
Example 1: ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"] would ignore layer 1 weights and bias.
Example 2: ignore_keys = ["state/model/*"] would ignore the entire model, which would have the same effect as the previous example if there was only 1 layer.
Example 3: ignore_keys = ["state/model/layer*.weights"] would ignore all weights in the model.
Example 4: ignore_keys = ["state/rank_zero_seed", "rng"] would reset all randomness when loading the checkpoint.
If a callable, it should take one argument which is the state_dict. The callable is free to arbitrarily modify the state_dict before it is loaded.
(default: None)
exclude_algorithms (list[str], optional) –
A list of algorithm names to exclude from loading. By default, algorithms with required_on_load=True which were enabled when training the loaded checkpoint are automatically applied unless they conflict with a user-specified algorithm. These algorithms often change the model, and not applying them could result in certain layers not having weights loaded.
Example 1: exclude_algorithms = ["BlurPool"] would exclude BlurPool from loading.
Example 2: exclude_algorithms = ["FusedLayerNorm", "Alibi"] would exclude FusedLayerNorm and Alibi from loading.
(default: None)
algorithm_passes (list[AlgorithmPass], optional) – A list of algorithm passes to apply to autoloaded algorithms to sort them into the correct order. (default: None)
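The ignore_keys semantics described above can be illustrated on a toy nested dict. The following is a minimal sketch, not Composer's implementation: flatten_paths, remove_path, apply_ignore_globs, and reset_randomness are hypothetical helpers, the toy dict only imitates the keys named in the examples, and glob matching is assumed to behave like Python's fnmatch. It shows both the list-of-globs form (Example 3) and the callable form (equivalent to Example 4).

```python
from fnmatch import fnmatchcase
from typing import Any, Dict, List


def flatten_paths(d: Dict[str, Any], prefix: str = "") -> List[str]:
    """Return every "/"-joined key path in a nested dict, parents before children."""
    paths: List[str] = []
    for key, value in d.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        paths.append(path)
        if isinstance(value, dict):
            paths.extend(flatten_paths(value, path))
    return paths


def remove_path(d: Dict[str, Any], path: str) -> None:
    """Delete the entry at a "/"-joined path; do nothing if it is already gone."""
    *parents, leaf = path.split("/")
    node: Any = d
    for key in parents:
        if not isinstance(node, dict) or key not in node:
            return
        node = node[key]
    if isinstance(node, dict):
        node.pop(leaf, None)


def apply_ignore_globs(state_dict: Dict[str, Any], globs: List[str]) -> None:
    """Remove every path matching any glob (hypothetical helper, not Composer's code)."""
    for pattern in globs:
        matches = [p for p in flatten_paths(state_dict) if fnmatchcase(p, pattern)]
        for path in matches:
            remove_path(state_dict, path)


def reset_randomness(state_dict: Dict[str, Any]) -> None:
    """Callable form of Example 4: edit the state_dict in place and return None."""
    state_dict.get("state", {}).pop("rank_zero_seed", None)
    state_dict.pop("rng", None)


# Toy stand-in for a checkpoint state_dict; the real structure is larger.
toy = {
    "state": {
        "model": {"layer1.weights": 1, "layer1.bias": 2, "layer2.weights": 3},
        "rank_zero_seed": 17,
    },
    "rng": [{"python": "..."}],
}

apply_ignore_globs(toy, ["state/model/layer*.weights"])
print(toy["state"]["model"])  # {'layer1.bias': 2}

reset_randomness(toy)
print(sorted(toy))  # ['state']
```

A function with the same shape as reset_randomness (one state_dict argument, in-place edits, no return value) is the kind of callable the ignore_keys description refers to.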
- Returns
Optional[list[dict[str, Any]]] – The RNG state dicts, indexed by global rank, if load_weights_only is False. Otherwise, None.
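As a rough illustration of the call and its return value, the sketch below wraps load_checkpoint using only keyword arguments from the signature above. The wrapper load_model_weights, its loader parameter (an injection point so the sketch can be exercised without Composer installed), and rng_state_for_rank are hypothetical names, not part of Composer. How a caller should apply the returned RNG state is not covered here.

```python
from typing import Any, Callable, Dict, List, Optional


def load_model_weights(
    path: str,
    state: Any,
    logger: Any,
    loader: Optional[Callable[..., Any]] = None,
) -> Any:
    """Load only the model weights, without requiring an exact key match.

    `loader` is a hypothetical injection point; when omitted, the real
    composer.utils.load_checkpoint is imported and used.
    """
    if loader is None:
        from composer.utils import load_checkpoint  # requires Composer

        loader = load_checkpoint
    return loader(
        path=path,
        state=state,
        logger=logger,
        load_weights_only=True,
        strict_model_weights=False,
    )


def rng_state_for_rank(
    rng_state_dicts: Optional[List[Dict[str, Any]]],
    global_rank: int,
) -> Optional[Dict[str, Any]]:
    """Pick this rank's entry from the returned list, or None if unavailable."""
    if rng_state_dicts is None or global_rank >= len(rng_state_dicts):
        return None
    return rng_state_dicts[global_rank]
```

Because the return value is Optional, callers should check for None before indexing by rank, as rng_state_for_rank does.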