load_checkpoint
- composer.utils.load_checkpoint(path, state, logger, object_store=None, load_weights_only=False, strict_model_weights=True, progress_bar=True, ignore_keys=None, exclude_algorithms=None, algorithm_passes=None)
Load a checkpoint from a local file, URI, or cloud object store into state.
- Parameters
  - path (str) – The path format string to an existing checkpoint file. It can be a path to a file on the local disk, a URL, or, if object_store is set, the object name for a checkpoint in a cloud bucket. When using FSDP with sharded checkpointing, checkpoint files are sharded by rank, and path should be set to the directory containing the sharded checkpoint files.
  - state (State) – The State to load the checkpoint into.
  - logger (Logger) – The Logger used to log any information.
  - object_store (Union[ObjectStore, LoggerDestination], optional) – If the path is in an object store (e.g., AWS S3 or Google Cloud Storage), an instance of ObjectStore or LoggerDestination that will be used to retrieve the checkpoint. If the checkpoint is a local file path, set to None. (default: None)
  - load_weights_only (bool, optional) – Whether to restore only the model weights from the checkpoint, without restoring the associated state. (default: False)
  - strict_model_weights (bool, optional) – Whether to require that the checkpointed weights exactly match the model weights. (default: True)
  - progress_bar (bool, optional) – Whether to show a progress bar when downloading checkpoints. Ignored if the checkpoint is a local file path. (default: True)
  - ignore_keys (list[str] | (dict) -> None, optional) –
    A list of paths in the checkpoint's state_dict to remove from the state_dict before the checkpoint is loaded. Each path is a list of strings specifying the keys to index into state_dict, joined together with / as a separator (because PyTorch uses . in parameter names). If a prefix is provided, all of its children are also ignored (see Example 2). See composer.core.state for the structure of state_dict.
    Example 1: ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"]
    would ignore the layer 1 weights and bias.
    Example 2: ignore_keys = ["state/model/*"]
    would ignore the entire model, which would have the same effect as Example 1 if the model had only one layer.
    Example 3: ignore_keys = ["state/model/layer*.weights"]
    would ignore all weights in the model.
    Example 4: ignore_keys = ["state/rank_zero_seed", "rng"]
    would reset all randomness when loading the checkpoint.
    If a callable, it should take one argument, the state_dict, and is free to modify the state_dict arbitrarily before it is loaded. (default: None)
  - exclude_algorithms (list[str], optional) –
    A list of algorithm names to exclude from loading. By default, algorithms with required_on_load=True that were enabled when training the loaded checkpoint are automatically applied, unless they conflict with a user-specified algorithm. These algorithms often change the model, and not applying them could result in certain layers not having weights loaded.
    Example 1: exclude_algorithms = ["BlurPool"]
    would exclude BlurPool from loading.
    Example 2: exclude_algorithms = ["FusedLayerNorm", "Alibi"]
    would exclude FusedLayerNorm and Alibi from loading.
    (default: None)
  - algorithm_passes (list[AlgorithmPass], optional) – A list of algorithm passes to apply to autoloaded algorithms to sort them into the correct order. (default: None)
- Returns
  Optional[list[dict[str, Any]]] – The RNG state dicts, indexed by global rank, if load_weights_only is False. Otherwise, None.
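The sketch below illustrates the "/"-joined path syntax of ignore_keys with a small, self-contained example. It is not Composer's implementation. It makes three assumptions:

- the state_dict is a nested dict with string keys;
- patterns are matched as shell-style globs with Python's fnmatch, so * also matches across /;
- matching an intermediate node removes its whole subtree, as with the prefix rule above.

The helper names flatten_paths, remove_path, and apply_ignore_keys are hypothetical.

```python
import copy
import fnmatch


def flatten_paths(tree: dict, prefix: str = "") -> list:
    """Return every '/'-joined path (intermediate nodes and leaves) in a nested dict."""
    paths = []
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else key
        paths.append(path)
        if isinstance(value, dict):
            paths.extend(flatten_paths(value, path))
    return paths


def remove_path(tree: dict, path: str) -> None:
    """Delete the entry at a '/'-joined path, if it still exists."""
    *parents, last = path.split("/")
    node = tree
    for key in parents:
        node = node.get(key)
        if not isinstance(node, dict):
            return  # an ancestor was already removed
    node.pop(last, None)


def apply_ignore_keys(state_dict: dict, ignore_keys: list) -> dict:
    """Hypothetical helper: return a copy of state_dict without glob-matched paths."""
    result = copy.deepcopy(state_dict)
    matched = {
        path
        for path in flatten_paths(result)
        if any(fnmatch.fnmatchcase(path, pattern) for pattern in ignore_keys)
    }
    # Remove shallower paths first; a removed prefix takes all of its children with it.
    for path in sorted(matched, key=lambda p: p.count("/")):
        remove_path(result, path)
    return result


# Toy state_dict loosely shaped like the ignore_keys examples (values are placeholders).
checkpoint = {
    "state": {
        "model": {"layer1.weights": 1.0, "layer1.bias": 0.1, "layer2.weights": 2.0},
        "rank_zero_seed": 42,
    },
    "rng": [{"python": 0}],
}

print(apply_ignore_keys(checkpoint, ["state/model/layer*.weights"]))
# {'state': {'model': {'layer1.bias': 0.1}, 'rank_zero_seed': 42}, 'rng': [{'python': 0}]}
print(apply_ignore_keys(checkpoint, ["state/rank_zero_seed", "rng"]))
# {'state': {'model': {'layer1.weights': 1.0, 'layer1.bias': 0.1, 'layer2.weights': 2.0}}}
```

Sorting matches by depth keeps the removal order-independent: once a prefix such as "state/model" is gone, any deeper match under it is simply skipped.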
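When ignore_keys is a callable, it receives the whole state_dict and may edit it in place before loading. The sketch below shows one such callable. drop_optimizer_state is a hypothetical name. The layout it assumes, a top-level "state" entry holding an "optimizers" key, is also an assumption for illustration; composer.core.state defines the real structure.

```python
from typing import Any


def drop_optimizer_state(state_dict: dict[str, Any]) -> None:
    """Hypothetical ignore_keys callable: discard saved optimizer state in place.

    Assumes a layout of the form {"state": {"optimizers": ..., ...}, ...};
    the real structure is defined by composer.core.state.
    """
    state_dict.get("state", {}).pop("optimizers", None)


# Toy state_dict used only to exercise the callable.
toy = {"state": {"model": {"fc.weight": 1.0}, "optimizers": {"Adam": {}}}, "rng": []}
drop_optimizer_state(toy)
print(toy)  # {'state': {'model': {'fc.weight': 1.0}}, 'rng': []}
```

Such a function would be passed as ignore_keys=drop_optimizer_state. Because the documented signature is (dict) -> None, all of its effect has to come from mutating the dict it is given.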
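strict_model_weights controls whether the checkpointed weights must match the model's weights exactly. The sketch below models that check on plain sets of parameter names. It follows the spirit of the strict flag of PyTorch's Module.load_state_dict, which reports missing and unexpected keys. It is an illustration only, not Composer's code, and check_weight_keys is a hypothetical name.

```python
def check_weight_keys(model_keys: set, checkpoint_keys: set, strict: bool = True):
    """Compare parameter names; raise on any mismatch when strict is True.

    Returns (missing, unexpected): names the model expects but the checkpoint
    lacks, and names the checkpoint has but the model does not.
    """
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing keys: {missing}, unexpected keys: {unexpected}")
    return missing, unexpected


model_keys = {"layer1.weight", "layer1.bias", "head.weight"}
ckpt_keys = {"layer1.weight", "layer1.bias", "old_head.weight"}

# Non-strict: mismatches are reported but tolerated.
print(check_weight_keys(model_keys, ckpt_keys, strict=False))
# (['head.weight'], ['old_head.weight'])
```

With the default strict=True, the same call would raise instead. This matches the idea that strict loading fails fast when, for example, a classification head was renamed between training runs.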
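The default autoloading rule that exclude_algorithms overrides can be modeled in a few lines. This is a simplified sketch, not Composer's logic. It makes two assumptions:

- the checkpoint records, for each algorithm enabled during training, whether that algorithm is required_on_load;
- "conflict with a user-specified algorithm" means the user already passed an algorithm with the same name.

The function autoloaded_algorithms is hypothetical, and the flag values in the example are illustrative.

```python
def autoloaded_algorithms(checkpoint_algorithms, user_algorithms, exclude_algorithms=None):
    """Hypothetical model of which checkpoint algorithms get applied automatically.

    checkpoint_algorithms: dict mapping algorithm name -> required_on_load flag.
    user_algorithms: names of algorithms the user passed explicitly.
    exclude_algorithms: names to skip, mirroring the parameter above.
    """
    excluded = set(exclude_algorithms or [])
    user = set(user_algorithms)
    return [
        name
        for name, required_on_load in checkpoint_algorithms.items()
        if required_on_load and name not in excluded and name not in user
    ]


# Illustrative flags; "ExampleAlgo" stands in for an algorithm not needed on load.
trained_with = {"BlurPool": True, "FusedLayerNorm": True, "Alibi": True, "ExampleAlgo": False}
print(autoloaded_algorithms(trained_with, user_algorithms=[]))
# ['BlurPool', 'FusedLayerNorm', 'Alibi']
print(autoloaded_algorithms(trained_with, [], exclude_algorithms=["FusedLayerNorm", "Alibi"]))
# ['BlurPool']
```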