State#
- class composer.State(model, rank_zero_seed, run_name, device, max_duration=None, device_train_microbatch_size=None, auto_microbatching=False, train_dataloader=None, evaluators=None, dataloader=None, dataloader_label=None, dataloader_len=-1, dataset_state=None, dataset_resumption=None, precision=Precision.FP32, precision_config=None, optimizers=None, scaler=None, save_metrics=False, algorithms=None, callbacks=None, parallelism_config=None)[source]#
The state of the trainer.
Contains variables that the trainer tracks throughout the training loop. Note that all the necessary parts (i.e., serialized_attributes) of state are serialized when the trainer is checkpointed so that it can be used to restore the trainer and continue training from a checkpoint.
Algorithms are able to modify an instance of this class in place.
Note
An instance of this class is automatically constructed by the Trainer constructor. A user need not instantiate this class.
- Parameters
model (Module) – The model, typically as a subclass of ComposerModel.
rank_zero_seed (int) – The seed used on the rank zero process. It is assumed that each rank's seed is rank_zero_seed + dist.get_global_rank().
run_name (str) – The name for this training run.
device (Device) – The device used by this process. The trainer moves the model and loaded data to this device.
device_train_microbatch_size (int | float, optional) – The microbatch size for each device during training.
auto_microbatching (bool, optional) – Whether automatic microbatching is enabled.
train_dataloader (Iterable, optional) – Dataloader used for training.
evaluators (Evaluator | Evaluators, optional) – Evaluator used for evaluation.
dataloader (Iterable, optional) – The active DataLoader.
dataloader_len (int | Time[int], optional) – The number of batches per dataloader iteration (e.g. epoch). The trainer will yield the first dataloader_len batches per iteration. If -1 (the default), the entire dataloader will be iterated over.
dataloader_label (str, optional) – The name for the dataloader. Required if dataloader is specified. (default: None)
By convention, the training dataloader is called 'train'. The evaluator dataloader is called 'eval', or when multiple evaluators are used, the name of the evaluator.
dataset_state (dict[str, Any], optional) – Mapping of dataset split to its iteration state for resumption.
dataset_resumption (dict[str, Any], optional) – Mapping of dataset split to whether resumption is used.
max_duration (str | Time, optional) – The maximum duration to train for. (default: None)
precision (str | Precision) – The numerical precision to use for training. See Precision for the supported precisions.
precision_config (Optional[dict[str, Any]]) – The config for FP8 scaling strategy. See parameters for DelayedScaling.
optimizers (Optimizer | Sequence[Optimizer], optional) – The optimizer being used to train the model. Multiple optimizers are not currently supported.
schedulers (LRScheduler | Sequence[LRScheduler], optional) – The learning rate scheduler (can also be a list or tuple of schedulers).
scaler (torch.amp.GradScaler, optional) – The gradient scaler in use for mixed precision training.
save_metrics (bool, optional) – Whether to save metrics in state_dict.
algorithms (Algorithm | Sequence[Algorithm], optional) – The algorithms used for training.
callbacks (Callback | Sequence[Callback], optional) – The callbacks used for training.
parallelism_config (ParallelismConfig, optional) – The configuration dictionary for parallelism.
- batch#
The batch. This will be the entire batch during Event.AFTER_DATALOADER, or a microbatch between Event.BATCH_START and Event.BATCH_END.
- Type
types.Batch
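The distinction between the entire batch and a microbatch can be pictured with a plain-Python sketch. It assumes a batch that is a simple list and a hypothetical helper named split_into_microbatches; it is an analogy for how a device batch might be divided according to device_train_microbatch_size, not Composer's implementation.

```python
def split_into_microbatches(batch, microbatch_size):
    """Hypothetical helper: cut a list-like batch into consecutive chunks."""
    return [
        batch[start:start + microbatch_size]
        for start in range(0, len(batch), microbatch_size)
    ]


# The "entire batch", as it might look right after the dataloader yields it.
full_batch = list(range(10))

# The pieces that would be processed one at a time within a single batch step.
microbatches = split_into_microbatches(full_batch, microbatch_size=4)
```

The last microbatch is smaller when the batch size is not a multiple of the microbatch size.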
- device#
The device used by this process. The trainer moves the model and loaded data to this device. This can be used in callbacks and algorithms to move data onto the correct device.
- Type
- train_metrics#
The current train metrics, organized by metric name.
train_metrics will be deep-copied to ensure that each evaluator updates only its train_metrics.
For example:
>>> trainer = Trainer(
...     ...,
...     train_dataloader=train_dataloader,
...     eval_dataloader=eval_dataloader,
... )
>>> trainer.fit()
>>> trainer.state.train_metrics
{'MulticlassAccuracy': MulticlassAccuracy()}
- eval_metrics#
The current evaluation metrics, organized by dataloader label and then by metric name. If not using an Evaluator, the eval dataloader is labeled 'eval'. Otherwise, in the case of having multiple evaluation datasets, the evaluator label is used. See the Multiple Datasets Documentation for more information.
eval_metrics will be deep-copied to ensure that each evaluator updates only its eval_metrics.
For example:
>>> from composer.metrics import CrossEntropy
>>> trainer = Trainer(
...     ...,
...     train_dataloader=train_dataloader,
...     eval_dataloader=eval_dataloader,
... )
>>> trainer.fit()
>>> trainer.state.eval_metrics
{'eval': {'CrossEntropy': CrossEntropy(), 'MulticlassAccuracy': MulticlassAccuracy()}}
Or, when using an Evaluator for multiple evaluation datasets:
>>> from composer.core import Evaluator
>>> trainer = Trainer(
...     ...,
...     train_dataloader=train_dataloader,
...     eval_dataloader=[
...         Evaluator(label='eval1', dataloader=eval_1_dl, metric_names=['MulticlassAccuracy']),
...         Evaluator(label='eval2', dataloader=eval_2_dl, metric_names=['MulticlassAccuracy']),
...     ],
... )
>>> trainer.fit()
>>> trainer.state.eval_metrics
{'eval1': {'MulticlassAccuracy': MulticlassAccuracy()}, 'eval2': {'MulticlassAccuracy': MulticlassAccuracy()}}
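The nesting by dataloader label and then by metric name, together with the deep-copy behaviour, can be illustrated without Composer. The sketch below is a plain-Python analogy under stated assumptions: FakeAccuracy and build_eval_metrics are hypothetical names, and the real metric objects are not involved.

```python
import copy


class FakeAccuracy:
    """Hypothetical stand-in for a metric object; it only counts updates."""

    def __init__(self):
        self.num_updates = 0

    def update(self):
        self.num_updates += 1


def build_eval_metrics(labels, metric_templates):
    """Return {label: {metric_name: metric}} with an independent copy per label."""
    return {
        label: {name: copy.deepcopy(metric) for name, metric in metric_templates.items()}
        for label in labels
    }


templates = {'MulticlassAccuracy': FakeAccuracy()}
eval_metrics = build_eval_metrics(['eval1', 'eval2'], templates)

# Updating the metric for one label leaves the other label's copy untouched.
eval_metrics['eval1']['MulticlassAccuracy'].update()
```

Because each label holds its own copy, results from one evaluation dataset cannot leak into another.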
- eval_timestamp#
The timestamp for the current evaluation dataloader. This timestamp is reset before the dataloader is evaluated. The epoch attribute for this timestamp is always 0.
- Type
- model#
The training model.
Note
When using multi-rank training with DDP, the model will be wrapped with DistributedDataParallel.
- Type
- outputs#
The most recently computed output from the model's forward pass.
- predict_timestamp#
The timestamp for the current prediction dataloader. This timestamp is reset before the dataloader is used. The epoch attribute for this timestamp is always 0.
- Type
- scaler#
The gradient scaler if using mixed-precision training, or None if not using mixed-precision training.
- Type
torch.amp.GradScaler
- serialized_attributes#
The names of the attributes that are serialized in a checkpoint.
By default, the following attributes are serialized:
| Attribute | Description |
|---|---|
| model | The model under training. |
| optimizers | The optimizers being used to train the model. |
| schedulers | The learning rate schedulers. |
| algorithms | The algorithms used for training. |
| callbacks | The callbacks used for training. |
| scaler | The gradient scaler in use for mixed precision training. |
| timestamp | The timestamp that tracks training loop progress. |
| rank_zero_seed | The seed of the rank zero process. |
| train_metrics | The current training metrics. |
| eval_metrics | The current evaluation metrics. |
| run_name | The run name for training. |
| dataset_state | The dataset iteration state. |
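The idea of a fixed list of attribute names driving what ends up in a checkpoint can be sketched in a few lines. MiniState below is a hypothetical, heavily simplified stand-in, not the real class: it shows only that listed attributes survive a save and restore round trip while transient ones, such as the current batch, do not.

```python
class MiniState:
    """Hypothetical, simplified stand-in for a trainer state object."""

    # Only these attribute names are included in the checkpoint.
    serialized_attributes = ['run_name', 'rank_zero_seed', 'timestamp']

    def __init__(self, run_name, rank_zero_seed):
        self.run_name = run_name
        self.rank_zero_seed = rank_zero_seed
        self.timestamp = {'epoch': 0, 'batch': 0}
        self.batch = None  # transient: deliberately not serialized

    def state_dict(self):
        return {name: getattr(self, name) for name in self.serialized_attributes}

    def load_state_dict(self, state):
        for name in self.serialized_attributes:
            if name in state:
                setattr(self, name, state[name])


original = MiniState(run_name='demo-run', rank_zero_seed=17)
original.timestamp = {'epoch': 2, 'batch': 640}
original.batch = ['some', 'data']
checkpoint = original.state_dict()

restored = MiniState(run_name='placeholder', rank_zero_seed=0)
restored.load_state_dict(checkpoint)
```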
- property algorithms#
The algorithms.
- batch_get_item(key)[source]#
Gets the element of the batch specified either by a key or by a user-specified function.
See batch_get in utils/batch_helpers.py for examples.
- Parameters
key (str | int | tuple[Callable, Callable] | Any, optional) – A key to index into the batch or a user-specified function to do the extracting. A pair of callables is also supported for cases where a get and set function pair are both passed (like in Algorithms). The getter is assumed to be the first of the pair.
- Returns
The part of the batch specified by the key. This could be any type, depending on what the batch is composed of.
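The key semantics described above (plain key, single function, or a getter and setter pair with the getter first) can be sketched as follows. The helper name get_from_batch and the example batches are hypothetical; this is an illustration of the dispatch logic, not the code in utils/batch_helpers.py.

```python
def get_from_batch(batch, key):
    """Hypothetical helper: fetch part of a batch by key or by a getter callable."""
    # A (getter, setter) pair: the getter is assumed to be the first element.
    if isinstance(key, tuple) and len(key) == 2 and all(callable(k) for k in key):
        getter, _setter = key
        return getter(batch)
    # A single user-specified function does the extracting.
    if callable(key):
        return key(batch)
    # Otherwise the key indexes into the batch (dict key or sequence position).
    return batch[key]


def first_label(batch):
    return batch['label'][0]


def replace_first_label(batch, value):
    # Unused by get_from_batch; present only to form a (getter, setter) pair.
    batch['label'][0] = value
    return batch


dict_batch = {'image': [1, 2, 3], 'label': [0, 1, 0]}
tuple_batch = ([1, 2, 3], [0, 1, 0])

images = get_from_batch(dict_batch, 'image')
targets = get_from_batch(tuple_batch, 1)
label = get_from_batch(dict_batch, (first_label, replace_first_label))
```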
- batch_set_item(key, value)[source]#
Sets the element specified by the key or the set_fn to the specified value.
This is not an in-place operation, as for tuple-typed batches, a new batch object must be created to modify them.
See batch_set in utils/batch_helpers.py for examples.
- Parameters
key (str | int | tuple[Callable, Callable] | Any, optional) – A key to index into the batch or a user-specified function to do the setting. A pair of callables is also supported for cases where a get and set function pair are both passed (like in Algorithms). The setter is assumed to be the second of the pair.
value (Any) – The value that batch[key] or batch.key gets set to or that the user-defined set function sets a part of the batch to.
- Returns
batch (Any) – The updated batch with value set at key.
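Why setting is not an in-place operation for tuple-typed batches can be seen in a small sketch. The helper set_in_batch is a hypothetical name used for illustration only; it mirrors the described behaviour (key, function, or getter and setter pair with the setter second) rather than reproducing the library code.

```python
def set_in_batch(batch, key, value):
    """Hypothetical helper: return a batch with `value` stored at `key`."""
    # A (getter, setter) pair: the setter is assumed to be the second element.
    if isinstance(key, tuple) and len(key) == 2 and all(callable(k) for k in key):
        _getter, setter = key
        return setter(batch, value)
    # A single user-specified function does the setting.
    if callable(key):
        return key(batch, value)
    if isinstance(batch, tuple):
        # Tuples are immutable, so build a new tuple with one position replaced.
        items = list(batch)
        items[key] = value
        return tuple(items)
    batch[key] = value  # mutable containers (dict, list) are updated directly
    return batch


tuple_batch = ('inputs', 'targets')
new_tuple_batch = set_in_batch(tuple_batch, 1, 'new_targets')

dict_batch = {'inputs': 1, 'targets': 2}
new_dict_batch = set_in_batch(dict_batch, 'targets', 3)
```

Since the tuple case returns a new object, callers should always use the returned batch instead of relying on the original being modified.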
- property callbacks#
The callbacks.
- property dataloader#
The active dataloader.
- property dataloader_label#
The dataloader label for the active dataloader.
By default, the training dataloader is called 'train'. The evaluator dataloader is called 'eval', or when multiple evaluators are used, the name of the evaluator. However, the dataloader label can be explicitly specified in Trainer.fit() and Trainer.eval().
- Returns
Optional[str] – The dataloader label, or None if no dataloader is set.
- property dataloader_len#
The number of batches per dataloader iteration (e.g. epoch), as used by the trainer.
Note
If not explicitly specified, this value is an approximation, as it depends on len(self.dataloader). See the PyTorch DataLoader Documentation for more information.
- Returns
Optional[Time[int]] – The number of batches per dataloader iteration (e.g. epoch), or None if no dataloader is defined or if the dataloader has an unknown length (e.g. streaming dataloaders).
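The effect of dataloader_len on iteration (yield only the first dataloader_len batches, or everything when it is -1) can be sketched with itertools.islice. The function iterate_batches is a hypothetical name and the list of strings stands in for a real dataloader; this is an illustration, not the trainer's loop.

```python
from itertools import islice


def iterate_batches(dataloader, dataloader_len=-1):
    """Yield at most `dataloader_len` batches; -1 means iterate over everything."""
    if dataloader_len == -1:
        yield from dataloader
    else:
        yield from islice(dataloader, dataloader_len)


batches = ['b0', 'b1', 'b2', 'b3']
first_two = list(iterate_batches(batches, dataloader_len=2))
everything = list(iterate_batches(batches))
```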
- property evaluators#
The evaluators.
- property fsdp_enabled#
Indicates if FSDP is enabled.
- get_elapsed_duration()[source]#
Get the elapsed training duration.
- Returns
Optional[Time[float]] – The elapsed duration, in TimeUnit.DURATION. Time(0.0, TimeUnit.DURATION) represents the beginning of training and Time(1.0, TimeUnit.DURATION) represents a completed training process. Returns None if max_duration is None.
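The meaning of the returned value can be sketched as a simple fraction. The function below is a hypothetical simplification that measures progress in batches only and returns a bare float instead of a Time object; it assumes max_batches is positive when given.

```python
from typing import Optional


def elapsed_fraction(elapsed_batches: int, max_batches: Optional[int]) -> Optional[float]:
    """Hypothetical sketch: fraction of training completed, or None without a maximum."""
    if max_batches is None:
        return None
    return elapsed_batches / max_batches


start = elapsed_fraction(0, 100)       # beginning of training
quarter = elapsed_fraction(25, 100)    # a quarter of the way through
finished = elapsed_fraction(100, 100)  # completed training
unknown = elapsed_fraction(25, None)   # no max_duration set
```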
- get_model_state_dict()[source]#
Collect the state dict for the model.
- Returns
dict[str, Any] – The state dict for the model.
- get_optim_state_dict()[source]#
Collect the state dict for the optimizer.
- Returns
dict[str, Any] – The state dict for the optimizer.
- load_model_state(state_dict, logger, strict, exclude_algorithms=None, algorithm_passes=None)[source]#
Loads the model's state from a state_dict.
- Parameters
state_dict (dict[str, Any]) – The state dict, generated from a previous call to state_dict().
logger (Logger) – The logger.
strict (bool) – Whether the keys (i.e., model parameter names) in the model state dict should perfectly match the keys in the model instance.
exclude_algorithms (list[str], optional) – List of algorithm names to exclude from autoloading. (default: None)
algorithm_passes (list[AlgorithmPass], optional) – A list of algorithm passes to apply to autoloaded algorithms to sort them into the correct order. (default: None)
- load_state_dict(state, logger, strict=False, exclude_algorithms=None, algorithm_passes=None)[source]#
Loads the state.
- Parameters
state (dict[str, Any]) – Object returned from a call to state_dict().
logger (Logger) – The logger.
strict (bool) – Whether the keys in state["model"] should perfectly match the keys in self.model. Defaults to False.
exclude_algorithms (list[str], optional) – List of algorithm names to exclude from autoloading. (default: None)
algorithm_passes (list[AlgorithmPass], optional) – A list of algorithm passes to apply to autoloaded algorithms to sort them into the correct order. (default: None)
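What "perfectly match" means for strict can be sketched as a comparison of key sets. The helper compare_keys and the key names are hypothetical; the sketch only illustrates the general idea of strict versus non-strict loading and does not claim to reproduce the checks the library performs.

```python
def compare_keys(model_keys, checkpoint_keys, strict=False):
    """Hypothetical sketch: report key mismatches, raising only when strict."""
    missing = sorted(set(model_keys) - set(checkpoint_keys))
    unexpected = sorted(set(checkpoint_keys) - set(model_keys))
    if strict and (missing or unexpected):
        raise RuntimeError(f'missing keys: {missing}; unexpected keys: {unexpected}')
    return missing, unexpected


model_keys = ['layer.weight', 'layer.bias']
checkpoint_keys = ['layer.weight', 'extra.buffer']

# Non-strict: mismatches are reported but loading could proceed.
missing, unexpected = compare_keys(model_keys, checkpoint_keys, strict=False)

# Strict: the same mismatch is treated as an error.
try:
    compare_keys(model_keys, checkpoint_keys, strict=True)
    strict_raised = False
except RuntimeError:
    strict_raised = True
```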
- property max_duration#
The maximum training duration.
- property optimizers#
The optimizers.
- property precision#
The numerical precision to use for training.
See
Precision
for the supported precisions.
- property precision_config#
The config for FP8 scaling strategy.
See parameters for DelayedScaling.
- property schedulers#
The schedulers.
- property seed#
The seed for the current rank.
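The constructor parameters state the assumption that each rank's seed is rank_zero_seed + dist.get_global_rank(). A minimal sketch of that arithmetic, with the global rank passed in explicitly instead of read from a distributed runtime (seed_for_rank is a hypothetical name):

```python
def seed_for_rank(rank_zero_seed: int, global_rank: int) -> int:
    """Hypothetical sketch: per-rank seed as an offset from the rank zero seed."""
    return rank_zero_seed + global_rank


# Seeds for a four-process run whose rank zero seed is 42.
seeds = [seed_for_rank(42, rank) for rank in range(4)]
```

Offsetting by the rank gives every process a distinct seed that is still reproducible from the single rank zero value.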
- set_dataloader(dataloader=None, dataloader_label=None, dataloader_len=-1)[source]#
Update the active dataloader and dataloader label.
- Parameters
dataloader (Iterable, optional) – The dataloader. Defaults to None.
dataloader_label (str, optional) – The dataloader label. Must be None if and only if dataloader is None. Defaults to None.
dataloader_len (int, optional) – The number of batches per dataloader iteration (e.g. epoch), as used by the trainer. Set to -1 to iterate over the entire dataset. (Default: -1.)
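The "if and only if" constraint between dataloader and dataloader_label amounts to requiring that both are None or neither is. A minimal sketch of that condition, using a hypothetical function name and making no claim about how the library reports a violation:

```python
def label_matches_dataloader(dataloader, dataloader_label):
    """Hypothetical check: the label is None exactly when the dataloader is None."""
    return (dataloader is None) == (dataloader_label is None)


both_set = label_matches_dataloader(['batch'], 'train')
both_unset = label_matches_dataloader(None, None)
label_without_loader = label_matches_dataloader(None, 'train')
loader_without_label = label_matches_dataloader(['batch'], None)
```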
- state_dict()[source]#
Collect the state dicts of our serializable attributes.
- Returns
dict[str, Any] – The state dict.
- stop_training()[source]#
Gracefully stop training.
The current batch of training will finish, along with any scheduled evaluation and logging for that batch and any epoch end events.
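A graceful stop differs from an abrupt one in that the request is only acted on at a safe point. The sketch below is a hypothetical toy loop (LoopState and run_epoch are invented names, and the real training loop is far more involved): the stop request is recorded mid-batch, the batch still completes, and the end-of-epoch step still runs.

```python
class LoopState:
    """Hypothetical state object holding only a stop flag."""

    def __init__(self):
        self.stop_requested = False

    def stop_training(self):
        self.stop_requested = True


def run_epoch(num_batches, stop_at):
    """Toy loop: a stop is requested during batch `stop_at`."""
    state = LoopState()
    log = []
    for batch_idx in range(num_batches):
        if batch_idx == stop_at:
            state.stop_training()  # e.g. requested by a callback mid-batch
        log.append(f'batch {batch_idx} finished')  # the current batch still completes
        if state.stop_requested:
            break  # checked only after the batch is done
    log.append('epoch end')  # end-of-epoch work still runs
    return log


log = run_epoch(num_batches=5, stop_at=1)
```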
- property train_dataloader#
Get the train dataloader.
- Returns
Iterable | DataLoader, optional – The dataloader.