Trainer
- class composer.Trainer(*, model, train_dataloader=None, train_dataloader_label='train', train_subset_num_batches=-1, spin_dataloaders=True, max_duration=None, algorithms=None, algorithm_passes=None, optimizers=None, schedulers=None, scale_schedule_ratio=1.0, step_schedulers_every_batch=None, eval_dataloader=None, eval_interval=1, eval_subset_num_batches=-1, callbacks=None, loggers=None, run_name=None, progress_bar=True, log_to_console=False, console_stream='stderr', console_log_interval='1ba', log_traces=False, auto_log_hparams=False, load_path=None, load_object_store=None, load_weights_only=False, load_strict_model_weights=True, load_progress_bar=True, load_ignore_keys=None, load_exclude_algorithms=None, save_folder=None, save_filename='ep{epoch}-ba{batch}-rank{rank}.pt', save_latest_filename='latest-rank{rank}.pt', save_overwrite=False, save_interval='1ep', save_weights_only=False, save_ignore_keys=None, save_num_checkpoints_to_keep=-1, save_metrics=False, autoresume=False, deepspeed_config=None, parallelism_config=None, device=None, precision=None, precision_config=None, device_train_microbatch_size=None, accumulate_train_batch_on_tokens=False, seed=None, deterministic_mode=False, dist_timeout=300.0, ddp_sync_strategy=None, profiler=None, python_log_level=None, compile_config=None)
Train models with Composer algorithms.
The trainer supports models with ComposerModel instances. The Trainer is highly customizable and can support a wide variety of workloads. See the training guide for more information.

Example
Train a model and save a checkpoint:
    import os

    from composer import Trainer

    # Create a trainer
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration="1ep",
        eval_dataloader=eval_dataloader,
        optimizers=optimizer,
        schedulers=scheduler,
        device="cpu",
        eval_interval="1ep",
        save_folder="checkpoints",
        save_filename="ep{epoch}.pt",
        save_interval="1ep",
        save_overwrite=True,
    )

    # Fit and run evaluation for 1 epoch.
    # Save a checkpoint after 1 epoch as specified during trainer creation.
    trainer.fit()
Load the checkpoint and resume training:
    # Get the saved checkpoint filepath
    checkpoint_path = trainer.saved_checkpoints.pop()

    # Create a new trainer with the `load_path` argument set to the checkpoint path.
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration="2ep",
        eval_dataloader=eval_dataloader,
        optimizers=optimizer,
        schedulers=scheduler,
        device="cpu",
        eval_interval="1ep",
        load_path=checkpoint_path,
    )

    # Continue training and running evaluation where the previous trainer left off
    # until the new max_duration is reached.
    # In this case it will be one additional epoch to reach 2 epochs total.
    trainer.fit()
- Parameters
model (ComposerModel) –
The model to train. Can be user-defined or one of the models included with Composer.
See also
composer.models for models built into Composer.

train_dataloader (Iterable | DataSpec | dict, optional) –
The dataloader, DataSpec, or dict of DataSpec kwargs for the training data. In order to specify custom preprocessing steps on each data batch, specify a DataSpec instead of a dataloader. It is recommended that the dataloader, whether specified directly or as part of a DataSpec, be a torch.utils.data.DataLoader.

Note
The train_dataloader should yield per-rank batches. Each per-rank batch will then be further divided based on the device_train_microbatch_size parameter. For example, if the desired optimization batch size is 2048 and training is happening across 8 GPUs, then each train_dataloader should yield a batch of size 2048 / 8 = 256. If device_train_microbatch_size = 128, then the per-rank batch will be divided into 256 / 128 = 2 microbatches of size 128.

If train_dataloader is not specified when constructing the trainer, it must be specified when invoking Trainer.fit().

train_dataloader_label (str, optional) –
The label for the train dataloader. (default: 'train')

This label is used to index the training metrics in State.train_metrics.

This parameter has no effect if train_dataloader is not specified.

train_subset_num_batches (int, optional) –
If specified, finish every epoch early after training on this many batches. This parameter has no effect if it is greater than len(train_dataloader). If -1, then the entire dataloader will be iterated over. (default: -1)

When using the profiler, it can be helpful to set this parameter to the length of the profile schedule. This setting will end each epoch early to avoid additional training that will not be profiled.
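The subset semantics above can be pictured with a small stand-alone sketch. It is not Composer's training loop: batches_for_epoch is a hypothetical helper, and a plain list stands in for the dataloader.

```python
from itertools import islice


def batches_for_epoch(dataloader, subset_num_batches: int = -1):
    """Yield at most subset_num_batches batches from one pass over dataloader.

    Hypothetical sketch of the documented semantics: -1 means "iterate over everything".
    """
    if subset_num_batches == -1:
        return iter(dataloader)
    # islice stops early; a limit larger than the dataloader simply yields every batch.
    return islice(dataloader, subset_num_batches)


fake_dataloader = [f"batch{i}" for i in range(5)]  # hypothetical 5-batch dataloader
print(list(batches_for_epoch(fake_dataloader)))            # all 5 batches
print(list(batches_for_epoch(fake_dataloader, 2)))         # ['batch0', 'batch1']
print(len(list(batches_for_epoch(fake_dataloader, 10))))   # 5, so a larger limit has no effect
```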
This parameter is ignored if train_dataloader is not specified.

spin_dataloaders (bool, optional) –
If True, dataloaders will be spun up to the current timestamp by skipping samples which have already been trained on. If a dataloader has a way to resume from the current batch without spinning, this will be a no-op. This ensures dataloaders continue from the same batch when resuming training. (default: True)

Note
Spinning dataloaders can be very slow, but it is required to skip samples which have already been trained on. If it is acceptable to repeat samples when resuming training, it is possible to resume faster by setting spin_dataloaders=False. This may significantly affect training results and is generally not recommended unless you fully understand the implications.

max_duration (Time | str | int, optional) –
The maximum duration to train. Can be an integer, which will be interpreted to be epochs, a str (e.g. 1ep or 10ba), or a Time object.

If max_duration is not specified when constructing the trainer, duration must be specified when invoking Trainer.fit().

algorithms (Algorithm | Sequence[Algorithm], optional) –
The algorithms to use during training. If None, then no algorithms will be used. (default: None)

See also
composer.algorithms for the different algorithms built into Composer.

algorithm_passes (AlgorithmPass | tuple[AlgorithmPass, int] | Sequence[AlgorithmPass | tuple[AlgorithmPass, int]], optional) –
Optional list of passes to change the order in which algorithms are applied. These passes are merged with the default passes specified in Engine. If None, then no additional passes will be used. (default: None)

See also
composer.core.Engine for more information.

optimizers (Optimizer, optional) –
The optimizer. If None, will be set to DecoupledSGDW(model.parameters(), lr=0.1). (default: None)

See also
composer.optim for the different optimizers built into Composer.

schedulers (LRScheduler | ComposerScheduler | Sequence[LRScheduler | ComposerScheduler], optional) –
The learning rate schedulers. If [] or None, the learning rate will be constant. (default: None)

See also
composer.optim.scheduler for the different schedulers built into Composer.

scale_schedule_ratio (float, optional) –
Ratio by which to scale the training duration and learning rate schedules. (default: 1.0)

E.g., 0.5 makes the schedule take half as many epochs and 2.0 makes it take twice as many epochs. 1.0 means no change.

This parameter has no effect if schedulers is not specified.

Note
Training for less time, while rescaling the learning rate schedule, is a strong baseline approach to speeding up training. E.g., training for half duration often yields only minor accuracy degradation, provided that the learning rate schedule is also rescaled to take half as long.
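The sketch below gives a rough numeric illustration of rescaling. It evaluates a plain cosine annealing formula at the end of a half-length run twice: once with the original schedule length and once with that length scaled by a ratio of 0.5. The cosine_lr function is a hypothetical stand-in written for this example. It is not one of the schedulers in composer.optim.scheduler.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float = 1.0) -> float:
    """Hypothetical cosine annealing from base_lr at step 0 down to 0 at total_steps."""
    progress = min(step / total_steps, 1.0)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


original_steps = 1000   # length of the schedule as originally configured
ratio = 0.5             # analogous to scale_schedule_ratio=0.5
trained_steps = int(original_steps * ratio)  # the shortened run stops after 500 steps

# Not rescaled: the schedule still assumes 1000 steps, so training stops mid-curve.
lr_truncated = cosine_lr(trained_steps, original_steps)
# Rescaled: the whole cosine curve is compressed into the 500-step run.
lr_rescaled = cosine_lr(trained_steps, int(original_steps * ratio))

print(f"final LR without rescaling: {lr_truncated:.3f}")  # 0.500
print(f"final LR with rescaling:    {lr_rescaled:.3f}")   # 0.000
```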
To see the difference, consider training for half as long using a cosine annealing learning rate schedule. If the schedule is not rescaled, training ends while the learning rate is still ~0.5 of the initial LR. If the schedule is rescaled with scale_schedule_ratio, the LR schedule finishes the entire cosine curve, ending with a learning rate near zero.

step_schedulers_every_batch (bool, optional) – By default, native PyTorch schedulers are updated every epoch, while Composer schedulers are updated every step. Setting this to True will force schedulers to be stepped every batch, while False means schedulers are stepped every epoch. None indicates the default behavior. (default: None)

eval_dataloader (Iterable | DataLoader | DataSpec | Evaluator | Sequence[Evaluator], optional) –
The Iterable, DataLoader, DataSpec, Evaluator, or sequence of evaluators for the evaluation data.

To evaluate one or more specific metrics across one or more datasets, pass in an Evaluator. If a DataLoader, DataSpec, or Iterable is passed in, then all metrics returned by model.get_metrics() will be used during evaluation. If an Evaluator is specified in a list, all eval dataloaders must be Evaluator instances. None results in no evaluation. (default: None)

eval_interval (int | str | Time | (State, Event) -> bool, optional) –
Specifies how frequently to run evaluation. An integer, which will be interpreted to be epochs, a str (e.g. 1ep or 10ba), a Time object, or a callable. Defaults to 1 (evaluate every epoch).

If an integer (in epochs), Time string, or Time instance, the evaluator will be run with this frequency. Time strings or Time instances must have units of TimeUnit.BATCH or TimeUnit.EPOCH.

Set to 0 to disable evaluation.

If a callable, it should take two arguments (State, Event) and return a bool representing whether the evaluator should be invoked. The event will be either Event.BATCH_END or Event.EPOCH_END.

This eval_interval will apply to any Evaluator in eval_dataloader that does not specify an eval_interval, or if a dataloader is passed in directly. This parameter has no effect if eval_dataloader is not specified.

When specifying a time string or integer for the eval_interval, the evaluator(s) are also run at Event.FIT_END if it doesn't evenly divide the training duration.

eval_subset_num_batches (int, optional) –
If specified, evaluate on this many batches. Defaults to -1, which means to iterate over the entire dataloader.

This parameter has no effect if eval_dataloader is not specified, it is greater than len(eval_dataloader), or eval_dataloader is an Evaluator and subset_num_batches was specified as part of the Evaluator.

callbacks (Callback | Sequence[Callback], optional) –
The callbacks to run during training. If None, then no callbacks will be run. (default: None)

See also
composer.callbacks for the different callbacks built into Composer.

loggers (LoggerDestination | Sequence[LoggerDestination], optional) –
The destinations to log training information to.
See also
composer.loggers for the different loggers built into Composer.

run_name (str, optional) – A name for this training run. If not specified, the env var COMPOSER_RUN_NAME or RUN_NAME will be used if set. Otherwise, the timestamp will be combined with a coolname, e.g. 1654298855-electric-zebra.

progress_bar (bool) – Whether to show a progress bar. (default: True)

log_to_console (bool) – Whether to print logging statements to the console. (default: False)

console_stream (TextIO | str, optional) – The stream to write to. If a string, it can either be 'stdout' or 'stderr'. (default: sys.stderr)

console_log_interval (int | str | Time, optional) –
Specifies how frequently to log metrics to the console. An integer, which will be interpreted to be epochs, a str (e.g. 1ep or 10ba), a Time object, or a callable. (default: 1ba, i.e. log metrics every batch)

If an integer (in epochs), Time string, or Time instance, the metrics will be logged with this frequency. Time strings or Time instances must have units of TimeUnit.BATCH or TimeUnit.EPOCH.

Set to 0 to disable metrics logging to the console.

log_traces (bool) – Whether to log traces. (default: False)

auto_log_hparams (bool) – Whether to automatically extract hyperparameters. (default: False)

load_path (str, optional) –
The path format string to an existing checkpoint file.
It can be a path to a file on the local disk, a URL, or, if load_object_store is set, the object name for a checkpoint in a cloud bucket. If a URI is specified, load_object_store does not need to be set.

When using DeepSpeed ZeRO, checkpoints are sharded by rank. Instead of hard-coding the rank in the path, use the following format variables:

    Variable       Description
    {rank}         The global rank, as returned by get_global_rank().
    {local_rank}   The local rank of the process, as returned by get_local_rank().
    {node_rank}    The node rank, as returned by get_node_rank().

For example, suppose that checkpoints are stored in the following structure:
    my_model/ep1-rank0.tar
    my_model/ep1-rank1.tar
    my_model/ep1-rank2.tar
    ...
Then, load_path should be set to my_model/ep1-rank{rank}.tar, and all ranks will load the correct state.

If None, then no checkpoint will be loaded. (default: None)

load_object_store (Union[ObjectStore, LoggerDestination], optional) –
If the load_path is in an object store (e.g. AWS S3 or Google Cloud Storage), an instance of ObjectStore or LoggerDestination that will be used to retrieve the checkpoint. Otherwise, if the checkpoint is a local filepath, set to None. It can also be None if the load_path is an S3 URI, because the appropriate object store will be constructed automatically in that case. Ignored if load_path is None. (default: None)

Example:
    from composer import Trainer
    from composer.utils import LibcloudObjectStore

    # Create the object store provider with the specified credentials
    creds = {"key": "object_store_key", "secret": "object_store_secret"}
    store = LibcloudObjectStore(provider="s3", container="my_container", provider_kwargs=creds)

    checkpoint_path = "./path_to_the_checkpoint_in_object_store"

    # Create a trainer which will load a checkpoint from the specified object store
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration="10ep",
        eval_dataloader=eval_dataloader,
        optimizers=optimizer,
        schedulers=scheduler,
        device="cpu",
        eval_interval="1ep",
        load_path=checkpoint_path,
        load_object_store=store,
    )
load_weights_only (bool, optional) – Whether to restore only the weights from the checkpoint, without restoring the associated state. Ignored if load_path is None. (default: False)

load_strict_model_weights (bool, optional) – Whether to require that the set of weights in the checkpoint exactly matches the set of weights in the model. Ignored if load_path is None. (default: True)

load_progress_bar (bool, optional) – Display the progress bar for downloading the checkpoint. Ignored if load_path is either None or a local file path. (default: True)

load_ignore_keys (list[str] | (dict) -> None, optional) –
A list of paths for the state_dict of the checkpoint which, when provided, will be removed from the state_dict before the checkpoint is loaded. Each path is a list of strings specifying the keys to index into state_dict, joined together with / as a separator (as PyTorch uses . in parameter names). If a prefix is provided, all children are also ignored (see Example 2). See composer.core.state for the structure of state_dict.

Example 1:
load_ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"]
would ignore layer 1 weights and bias.

Example 2:
load_ignore_keys = ["state/model/*"]
would ignore the entire model, which would have the same effect as the previous example if there were only 1 layer.

Example 3:
load_ignore_keys = ["state/model/layer*.weights"]
would ignore all weights in the model.

Example 4:
load_ignore_keys = ["state/rank_zero_seed", "rng"]
would reset all randomness when loading the checkpoint.

If a callable, it should take one argument, which is the state_dict. The callable is free to arbitrarily modify the state_dict before it is loaded.
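The path rules above are illustrated by a hypothetical sketch. Nested keys are joined with /, * acts as a wildcard, and a matching prefix removes its whole subtree. This is not Composer's implementation. The checkpoint layout is a heavily simplified, made-up example, and the wildcard matching is assumed to behave like Python's fnmatch.

```python
import copy
import fnmatch


def remove_ignored_keys(state_dict: dict, ignore_keys: list[str]) -> dict:
    """Return a copy of state_dict without any path matching one of ignore_keys.

    Paths are nested dict keys joined with "/". When a path matches, its whole
    subtree is dropped, so a prefix such as "rng" removes all of its children.
    """
    result = copy.deepcopy(state_dict)  # leave the caller's dict untouched

    def prune(node: dict, prefix: str) -> None:
        for key in list(node):
            path = f"{prefix}/{key}" if prefix else key
            if any(fnmatch.fnmatchcase(path, pattern) for pattern in ignore_keys):
                del node[key]
            elif isinstance(node[key], dict):
                prune(node[key], path)

    prune(result, "")
    return result


# Hypothetical, simplified checkpoint layout for illustration only.
checkpoint = {
    "state": {
        "model": {"layer1.weights": [1.0], "layer1.bias": [0.0], "layer2.weights": [2.0]},
        "rank_zero_seed": 42,
    },
    "rng": [{"python": "..."}],
}

print(remove_ignored_keys(checkpoint, ["state/model/layer*.weights"]))
print(remove_ignored_keys(checkpoint, ["state/rank_zero_seed", "rng"]))
```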
(default: None)

load_exclude_algorithms (list[str], optional) –
A list of algorithm names to exclude from loading. By default, algorithms with required_on_load=True that were enabled when training the loaded checkpoint are automatically applied, unless they conflict with a user-specified algorithm. These algorithms often change the model, and not applying them could result in certain layers not having weights loaded.
Example 1:
load_exclude_algorithms = ["BlurPool"]
would exclude BlurPool from loading.

Example 2:
load_exclude_algorithms = ["FusedLayerNorm", "Alibi"]
would exclude FusedLayerNorm and Alibi from loading.

(default: None)

save_folder (str, optional) –
Format string for the folder where checkpoints are saved. If None, checkpoints will not be saved. Can also be a URI for S3 paths only. In the case of an S3 URI, the appropriate RemoteUploader object will be created automatically. (default: None)

Note
For fine-grained control on checkpoint saving (e.g. to save different types of checkpoints at different intervals), leave this parameter as None, and instead pass instance(s) of CheckpointSaver directly as callbacks.

save_filename (str, optional) –
A format string describing how to name checkpoints. This parameter has no effect if save_folder is None. (default: "ep{epoch}-ba{batch}-rank{rank}.pt")
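Checkpoint names are built by filling in the placeholders in save_filename. The sketch below assumes the substitution behaves like Python's str.format, which is a simplification, since the trainer may support placeholders beyond the three shown. The helper render_checkpoint_name and the epoch, batch, and rank values are hypothetical.

```python
def render_checkpoint_name(template: str, **values: int) -> str:
    """Fill {placeholders} in a checkpoint name template (assumes str.format semantics)."""
    return template.format(**values)


save_filename = "ep{epoch}-ba{batch}-rank{rank}.pt"  # the documented default

# Hypothetical progress: epoch 3, batch 1500, on a 2-GPU run (one file per rank).
names = [render_checkpoint_name(save_filename, epoch=3, batch=1500, rank=r) for r in range(2)]
print(names)  # ['ep3-ba1500-rank0.pt', 'ep3-ba1500-rank1.pt']
```

The same kind of per-rank substitution is what lets a single load_path pattern such as my_model/ep1-rank{rank}.tar resolve to a different file on each rank.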
save_latest_filename (str, optional) –
A format string for the name of a symlink (relative to save_folder) that points to the last saved checkpoint. This parameter has no effect if save_folder is None. To disable symlinking, set this to None. (default: "latest-rank{rank}.pt")
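The following is a minimal sketch of the symlink idea that uses only the standard library. After each placeholder checkpoint is written, a latest-rank0.pt link is re-pointed at it with a target path relative to the folder. The temporary folder, the file names, and the update logic are assumptions for illustration. This is not Composer's checkpointing code.

```python
import os
import tempfile

save_folder = tempfile.mkdtemp()                          # stand-in for a real save_folder
latest_filename = "latest-rank{rank}.pt".format(rank=0)  # documented default, rank 0
link_path = os.path.join(save_folder, latest_filename)

for checkpoint_name in ["ep1-ba100-rank0.pt", "ep2-ba200-rank0.pt"]:  # hypothetical names
    # Write an empty placeholder file instead of a real checkpoint.
    open(os.path.join(save_folder, checkpoint_name), "wb").close()

    # Re-point the symlink at the checkpoint that was just saved.
    if os.path.lexists(link_path):
        os.remove(link_path)
    os.symlink(checkpoint_name, link_path)  # relative target, resolved inside save_folder

print(os.readlink(link_path))  # ep2-ba200-rank0.pt
```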
save_overwrite (bool, optional) –
Whether existing checkpoints should be overwritten. This parameter has no effect if save_folder is None. (default: False)
save_interval (Time | str | int | (State, Event) -> bool) –
A Time, time string, integer (in epochs), or function that takes (state, event) and returns a boolean indicating whether a checkpoint should be saved. This parameter has no effect if save_folder is None. (default: '1ep')
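The sketch below shows the callable form of save_interval. A factory builds a (state, event) -> bool predicate that returns True at every n-th batch end. It is hypothetical: state only needs a batch counter here, and event is a plain string. The trainer itself passes its real State and Event objects.

```python
from types import SimpleNamespace


def save_every_n_batches(n: int):
    """Return a (state, event) -> bool predicate that is True at every n-th batch end.

    Hypothetical helper: `state.batch` and the string events are simplified stand-ins.
    """

    def should_save(state, event) -> bool:
        return event == "BATCH_END" and state.batch % n == 0

    return should_save


should_save = save_every_n_batches(500)
print(should_save(SimpleNamespace(batch=1000), "BATCH_END"))  # True
print(should_save(SimpleNamespace(batch=1001), "BATCH_END"))  # False
print(should_save(SimpleNamespace(batch=1000), "EPOCH_END"))  # False
```

With the real trainer, such a predicate would compare the event against Event.BATCH_END and read the batch count from the state's timestamp. Treat the exact attribute access as an assumption to check against the State documentation.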
save_weights_only (bool, optional) –
Whether to save only the model weights instead of the entire training state. This parameter has no effect if save_folder is None. (default: False)
save_ignore_keys (list[str] | (dict) -> None, optional) –
A list of paths for the state_dict of the checkpoint which, when provided, will be removed from the state_dict before the checkpoint is saved. Each path is a list of strings specifying the keys to index into state_dict, joined together with / as a separator (as PyTorch uses . in parameter names). If a prefix is provided, all children are also ignored (see Example 2). See composer.core.state for the structure of state_dict.

Example 1:
save_ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"]
would ignore layer 1 weights and bias.

Example 2:
save_ignore_keys = ["state/model/*"]
would ignore the entire model, which would have the same effect as the previous example if there were only 1 layer.

Example 3:
save_ignore_keys = ["state/model/layer*.weights"]
would ignore all weights in the model.

Example 4:
save_ignore_keys = ["state/rank_zero_seed", "rng"]
would leave the seed and RNG state out of the saved checkpoint.

If a callable, it should take one argument, which is the state_dict. The callable is free to arbitrarily modify the state_dict before it is saved.
(default: None)

save_num_checkpoints_to_keep (int, optional) –
The number of checkpoints to keep locally. The oldest checkpoints are removed first. Set to -1 to keep all checkpoints locally. (default: -1)

Checkpoints will be removed after they have been uploaded. For example, when this parameter is used in conjunction with the RemoteUploaderDownloader, set it to 0 to immediately delete checkpoints from the local disk after they have been uploaded to the object store.

This parameter only controls how many checkpoints are kept locally; checkpoints are not deleted from remote file systems.
save_metrics (bool, optional) – Whether to save the metrics. By default, metrics are not saved to the checkpoint, as their state usually does not need to be preserved and inconsistent state can cause issues when loading. (default: False)

autoresume (bool, optional) –
Whether to enable autoresume, which allows for stopping and resuming training. This allows the use of spot instances, as the training run is now fault tolerant. This parameter requires save_folder and run_name to be specified. (default: False)

When enabled, the save_folder is checked for checkpoints of the format "{save_folder}/{save_latest_filename}", which are loaded to continue training. If no local checkpoints are found, each logger is checked for potential remote checkpoints named "{save_folder}/{save_latest_filename}". Finally, if no logged checkpoints are found, load_path is used to load a checkpoint if specified. This should only occur at the start of a run using autoresume.

For example, to run a fine-tuning run on a spot instance, load_path would be set to the original weights and an object store logger would be added. In the original run, load_path would be used to get the starting checkpoint. For any future restarts, such as due to the spot instance being killed, the loggers would be queried for the latest checkpoint, which the object store logger would download and use to resume training.

deepspeed_config (dict[str, Any], optional) –
Configuration for DeepSpeed, formatted as JSON according to DeepSpeed's documentation. (default: None)

To use DeepSpeed with default values, set to the empty dictionary {}. To disable DeepSpeed (the default), set to None.

parallelism_config (Union[dict[str, Any], ParallelismConfig], optional) –
Configuration for parallelism options. Currently supports FSDP and tensor parallelism, whose respective configs are specified as the keys fsdp and tp. (default: None)

- For parallelism_config['fsdp'], see FSDP Documentation for more details. To use FSDP with default values, set to the empty dictionary {}. To disable FSDP, set to None or remove the key from the dictionary.
- For parallelism_config['tp'], see TP Documentation for more details. To use Tensor Parallelism with default values, set to the empty dictionary {}. To disable Tensor Parallelism, set to None or remove the key from the dictionary.
Note
This parameter is experimental and subject to change without standard deprecation cycles.
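Putting the conventions above together, the following sketch shows a parallelism_config that enables FSDP with its default values and leaves tensor parallelism disabled. Only the fsdp and tp keys and the {} / None conventions come from the text. No individual FSDP or TP options are shown, because their names are defined in the respective documentation.

```python
# Sketch of a parallelism_config dict following the documented key/value conventions.
parallelism_config = {
    "fsdp": {},   # empty dict: use FSDP with its default values
    "tp": None,   # None (or leaving the key out): tensor parallelism disabled
}

# Which strategies this config turns on.
enabled = [name for name, cfg in parallelism_config.items() if cfg is not None]
print(enabled)  # ['fsdp']

# It would then be passed to the trainer, e.g. Trainer(..., parallelism_config=parallelism_config).
```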
device (Device | str, optional) –
The device to use for training, which can be 'cpu', 'gpu', 'tpu', or 'mps'. (default: None)

The default behavior sets the device to 'gpu' if CUDA is available, and otherwise 'cpu'.

precision (Precision | str, optional) – Numerical precision to use for training. One of fp32, amp_bf16, or amp_fp16 (recommended). (default: Precision.FP32 if training on CPU; Precision.AMP_FP16 if training on GPU)

precision_config (Optional[dict[str, Any]]) – The config for the FP8 scaling strategy. See the parameters for DelayedScaling.
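The sketch below reproduces, with plain lists, two pieces of batch arithmetic. The first comes from the train_dataloader note: a global batch of 2048 on 8 GPUs gives per-rank batches of 256, which are split into microbatches of 128. The second is the behavior for uneven batches that device_train_microbatch_size describes. The helper split_into_microbatches is hypothetical and is not the trainer's implementation.

```python
def split_into_microbatches(batch: list, microbatch_size: int) -> list[list]:
    """Split one per-rank batch into consecutive microbatches (simplified sketch).

    If len(batch) is not divisible by microbatch_size, the last microbatch is smaller.
    """
    return [batch[i:i + microbatch_size] for i in range(0, len(batch), microbatch_size)]


global_batch_size = 2048
num_gpus = 8
per_rank_batch_size = global_batch_size // num_gpus  # 2048 / 8 = 256 samples per rank

per_rank_batch = list(range(per_rank_batch_size))  # stand-in for 256 samples
microbatches = split_into_microbatches(per_rank_batch, 128)
print(len(microbatches), [len(mb) for mb in microbatches])  # 2 [128, 128]

# A batch size that does not divide evenly leaves a smaller final microbatch.
print([len(mb) for mb in split_into_microbatches(list(range(300)), 128)])  # [128, 128, 44]
```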
device_train_microbatch_size (Union[int, float, str], optional) –
The number of samples to process on each device per microbatch during training. Gradients are summed over the microbatches per device. If set to auto, device_train_microbatch_size is decreased dynamically if a microbatch is too large for the GPU. (default: None)

Note
This is implemented by taking the batch yielded by the train_dataloader and splitting it into sections of size device_train_microbatch_size. If the batch size of the dataloader is not divisible by device_train_microbatch_size, the last section may be smaller.

accumulate_train_batch_on_tokens (bool, optional) – Whether training loss is accumulated over the number of tokens in a batch, rather than the number of samples. Only works if the train data spec implements get_num_tokens_in_batch. Note: If you are using this flag, you can optionally have your get_num_tokens_in_batch function return a dictionary with two keys (total and loss_generating). Composer will then accumulate the batch on loss-generating tokens specifically, while total tokens will still be used for any other token-based time accounting. (default: False)

seed (int, optional) –
The seed used in randomization. If None, then a random seed will be created. (default: None)

Note
In order to get reproducible results, call the seed_all() function at the start of your script with the seed passed to the trainer. This will ensure any initialization done before the trainer init (e.g. model weight initialization) also uses the provided seed.

See also
composer.utils.reproducibility for more details on reproducibility.

deterministic_mode (bool, optional) –
Run the model deterministically. (default: False)

Note
This is an experimental feature. Expect performance degradation. Certain Torch modules may not have deterministic implementations, which will result in a crash.

Note
In order to get reproducible results, call the configure_deterministic_mode() function at the start of your script. This will ensure any initialization done before the trainer init also runs deterministically.

See also
composer.utils.reproducibility for more details on reproducibility.

dist_timeout (float, optional) – Timeout, in seconds, for initializing the distributed process group. (default: 300.0)

ddp_sync_strategy (str | DDPSyncStrategy, optional) – The strategy to use for synchronizing gradients. Leave unset to let the trainer auto-configure this. See DDPSyncStrategy for more details.

profiler (Profiler, optional) –
The profiler, if profiling should be enabled. (default: None)

See also
See the Profiling Guide for additional information.
python_log_level (str, optional) –
The Python log level to use for log statements in the composer module. (default: None) If it is None, Python logging will not be configured (i.e. logging.basicConfig won't be called).

See also
The logging module in Python.

compile_config (dict[str, Any], optional) – Configuration for torch compile. Only supported with PyTorch 2.0 or higher. See [torch.compile](https://pytorch.org/get-started/pytorch-2.0/) for more details. To use torch compile with default values, set it to the empty dictionary {}. To use torch compile with a custom config, set it to a dictionary such as {'mode': 'max-autotune'}. To disable torch compile, set it to None. (default: None)
- close()
Shut down the trainer.
See also
Engine.close()
for additional information.
- eval(eval_dataloader=None, subset_num_batches=-1)
Run evaluation loop.
Results are stored in trainer.state.eval_metrics. The eval_dataloader can be provided either to the eval() method or to the trainer's init().

Examples:
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration="2ep",
        device="cpu",
    )

    trainer.fit()

    # run eval
    trainer.eval(
        eval_dataloader=eval_dataloader,
    )
Or, if the eval_dataloader is provided during init:

    trainer = Trainer(
        model=model,
        eval_dataloader=eval_dataloader,
        train_dataloader=train_dataloader,
        max_duration="2ep",
        device="cpu",
    )

    trainer.fit()

    # eval_dataloader already provided:
    trainer.eval()
For multiple metrics or dataloaders, use Evaluator to provide identifier names. For example, to run the GLUE task:

    from composer.core import Evaluator

    glue_mrpc_task = Evaluator(
        label='glue_mrpc',
        dataloader=mrpc_dataloader,
        metric_names=['BinaryF1Score', 'MulticlassAccuracy']
    )

    glue_mnli_task = Evaluator(
        label='glue_mnli',
        dataloader=mnli_dataloader,
        metric_names=['MulticlassAccuracy']
    )

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration="1ep",
        eval_dataloader=[glue_mrpc_task, glue_mnli_task],
    )
The metrics used are defined in your model's get_metrics() method. For more information, see Evaluation.

Note
If evaluating with multiple GPUs using a DistributedSampler with drop_last=False, the last batch may contain duplicate samples, which can affect metrics. As long as the dataset passed to the DistributedSampler has a defined length, Composer will drop these duplicate samples.
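The following simplified re-creation shows where the duplicates come from. It mimics the index padding that torch.utils.data.distributed.DistributedSampler performs with shuffle=False and drop_last=False, for a case where the padding is no longer than the dataset. The helper distributed_indices is hypothetical. The sketch is not Composer's deduplication code. It only shows that, once the dataset length is known, the padded positions can be identified.

```python
import math


def distributed_indices(dataset_len: int, num_replicas: int, rank: int) -> list[int]:
    """Mimic DistributedSampler's padding (shuffle=False, drop_last=False).

    Indices are padded by repeating the first few so that every rank gets the same
    number of samples; this padding is what produces duplicate samples.
    """
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    indices += indices[: total_size - dataset_len]  # assumes padding <= dataset_len
    return indices[rank:total_size:num_replicas]


dataset_len, world_size = 10, 4
per_rank = [distributed_indices(dataset_len, world_size, r) for r in range(world_size)]
print(per_rank)  # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]

# 12 slots are evaluated for only 10 distinct samples: 0 and 1 are seen twice.
all_seen = [i for rank_indices in per_rank for i in rank_indices]
print(len(all_seen), len(set(all_seen)))  # 12 10
```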
- Parameters
eval_dataloader (Iterable | DataLoader | DataSpec | Evaluator | Sequence[Evaluator], optional) – Dataloaders for evaluation. If not provided, defaults to using the eval_dataloader provided to the trainer init().

subset_num_batches (int, optional) – Evaluate on this many batches. Defaults to -1 (the entire dataloader). Can also be provided in trainer.__init__() as eval_subset_num_batches.
- export_for_inference(save_format, save_path, save_object_store=None, sample_input=None, transforms=None, input_names=None, output_names=None)
Export a model for inference.
- Parameters
save_format (Union[str, ExportFormat]) – Format to export to. Either "torchscript" or "onnx".

save_path (str) – The path for storing the exported model. It can be a path to a file on the local disk, a URL, or, if save_object_store is set, the object name in a cloud bucket. For example, my_run/exported_model.

save_object_store (ObjectStore, optional) – If the save_path is an object name in a cloud bucket (e.g. AWS S3 or Google Cloud Storage), an instance of ObjectStore which will be used to store the exported model. If this is set to None, the model will be saved to save_path using the trainer's logger. (default: None)

sample_input (Any, optional) – Example model inputs used for tracing. This is needed for "onnx" export. The sample_input need not match the batch size you intend to use for inference. However, the model should accept the sample_input as is. (default: None)

transforms (Sequence[Transform], optional) – Transformations (usually optimizations) that should be applied to the model. Each Transform should be a callable that takes a model and returns a modified model.

input_names (Sequence[str], optional) – Names to assign to the input nodes of the graph, in order. If set to None, the keys from the sample_input will be used, falling back to ["input"].

output_names (Sequence[str], optional) – Names to assign to the output nodes of the graph, in order. If set to None, it defaults to ["output"].
- Returns
None
- fit(*, train_dataloader=None, train_dataloader_label='train', train_subset_num_batches=None, spin_dataloaders=None, duration=None, reset_time=False, schedulers=None, scale_schedule_ratio=1.0, step_schedulers_every_batch=None, eval_dataloader=None, eval_subset_num_batches=-1, eval_interval=1, device_train_microbatch_size=None, precision=None)
Train the model.
The Composer Trainer supports multiple calls to fit(). Any arguments specified during the call to fit() will override the values specified when constructing the Trainer. All arguments are optional, with the following exceptions:

- The train_dataloader must be specified here if not provided when constructing the Trainer.
- The duration must be specified here if not provided when constructing the Trainer, or if this is a subsequent call to fit().
For example, the following are equivalent:
    # The `train_dataloader` and `duration` can be specified
    # when constructing the Trainer
    trainer_1 = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration="1ep",
    )
    trainer_1.fit()

    # Or, these arguments can be specified on `fit()`
    # (`model` is keyword-only in the Trainer signature)
    trainer_2 = Trainer(model=model)
    trainer_2.fit(
        train_dataloader=train_dataloader,
        duration="1ep"
    )
When invoking fit() for a subsequent time, either reset_time or duration must be specified. Otherwise, it is ambiguous how long to train.

- If reset_time is True, then fit() will train for the same amount of time as the previous call (or for duration if that parameter is also specified). The State.timestamp will be reset, causing ComposerScheduler and Algorithm instances to start from the beginning, as if it is a new training run. Model gradients, optimizer states, and native PyTorch schedulers will not be reset.
- If reset_time is False, then fit() will train for the amount of time specified by duration. The State.max_duration will be incremented by duration.
For example:
    # Construct the trainer
    trainer = Trainer(model=model, train_dataloader=train_dataloader, max_duration="1ep")

    # Train for 1 epoch
    trainer.fit()
    assert trainer.state.timestamp.epoch == "1ep"

    # Reset the time to 0, then train for 1 epoch
    trainer.fit(reset_time=True)
    assert trainer.state.timestamp.epoch == "1ep"

    # Train for another epoch (2 epochs total)
    trainer.fit(duration="1ep")
    assert trainer.state.timestamp.epoch == "2ep"

    # Train for another batch (2 epochs + 1 batch total)
    # It's OK to switch time units!
    trainer.fit(duration="1ba")
    assert trainer.state.timestamp.epoch == "2ep"
    assert trainer.state.timestamp.batch_in_epoch == "1ba"

    # Reset the time, then train for 3 epochs
    trainer.fit(reset_time=True, duration="3ep")
    assert trainer.state.timestamp.epoch == "3ep"
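The reset_time / duration bookkeeping described above can be modeled with a toy class that counts epochs only. ToyClock is hypothetical and does not model unit conversion (such as the switch to "1ba" in the example). It raises an error in the ambiguous case where neither duration nor reset_time is given and there is nothing left to train.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyClock:
    """Hypothetical stand-in for State.timestamp / State.max_duration, counting epochs only."""

    epoch: int = 0          # current training time
    max_duration: int = 0   # total time to train for

    def fit(self, duration: Optional[int] = None, reset_time: bool = False) -> None:
        if reset_time:
            self.epoch = 0                      # start over, as if it were a new run
            if duration is not None:
                self.max_duration = duration    # reset_time=True: max_duration is set to duration
        elif duration is not None:
            self.max_duration += duration       # reset_time=False: max_duration is incremented
        elif self.epoch >= self.max_duration:
            raise ValueError("specify duration or reset_time: nothing left to train")
        self.epoch = self.max_duration          # "train" until max_duration is reached


clock = ToyClock(max_duration=1)        # like Trainer(..., max_duration="1ep")
clock.fit()                             # epoch 1
clock.fit(reset_time=True)              # back to 0, then epoch 1 again
clock.fit(duration=1)                   # max_duration 2, epoch 2
clock.fit(reset_time=True, duration=3)  # max_duration 3, epoch 3
print(clock)  # ToyClock(epoch=3, max_duration=3)
```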
- Parameters
  - train_dataloader (Iterable | DataSpec | dict[str, Any], optional) – See `Trainer`.
  - reset_time (bool) – Whether to reset the `State.timestamp` to zero values. Defaults to False.
    If True, the timestamp will be zeroed out, causing `ComposerScheduler` and `Algorithm` instances to start from the beginning, as if it is a new training run. The model will be trained for `duration`, if specified, or for `State.max_duration`, which would have been provided when constructing the `Trainer` or by a previous call to `fit()`.
    Note: Model gradients, optimizer states, and native PyTorch schedulers will not be reset.
    If False (the default), training time will be incremented from where the previous call to `fit()` finished (or from zero, if this is a new training run). The `max_duration` will be incremented by the `duration` parameter.
  - duration (Time[int] | str | int, optional) – The duration to train. Can be an integer, which will be interpreted as epochs, a str (e.g. `1ep` or `10ba`), or a `Time` object.
    If `reset_time` is False (the default), then `State.max_duration` will be converted into the same units as this parameter (if necessary), and then incremented by the value of this parameter.
    If `reset_time` is True, then `State.max_duration` will be set to this parameter.
  - optimizers (Optimizer | Sequence[Optimizer], optional) – See `Trainer`.
  - schedulers (LRScheduler | ComposerScheduler | Sequence[LRScheduler | ComposerScheduler], optional) – See `Trainer`.
  - step_schedulers_every_batch (bool, optional) – See `Trainer`.
  - eval_dataloader (Iterable | DataSpec | Evaluator | Sequence[Evaluator], optional) – See `Trainer`.
  - eval_interval (int | str | Time | (State, Event) -> bool, optional) – See `Trainer`.
  - device_train_microbatch_size (int | float | str, optional) – See `Trainer`.
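As a rough illustration of the `duration` rules above, the sketch below models how `max_duration` changes on a later `fit()` call, in plain Python. It is not Composer's implementation: `Duration`, `parse_duration`, `to_unit`, and `next_max_duration` are hypothetical names, only epochs (`ep`) and batches (`ba`) are modeled, and every epoch is assumed to contain a fixed `batches_per_epoch`.

```python
import re
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Duration:
    """Hypothetical stand-in for composer.core.Time (epochs or batches only)."""
    value: int
    unit: str  # "ep" or "ba"


def parse_duration(duration: Union[int, str, Duration]) -> Duration:
    """Normalize an int, a string such as "10ba", or a Duration."""
    if isinstance(duration, Duration):
        return duration
    if isinstance(duration, int):
        # A bare integer is interpreted as a number of epochs.
        return Duration(duration, "ep")
    match = re.fullmatch(r"(\d+)(ep|ba)", duration.strip())
    if match is None:
        raise ValueError(f"unsupported duration: {duration!r}")
    return Duration(int(match.group(1)), match.group(2))


def to_unit(d: Duration, unit: str, batches_per_epoch: int) -> Duration:
    """Convert between epochs and batches, assuming a fixed epoch length."""
    if d.unit == unit:
        return d
    if unit == "ba":
        return Duration(d.value * batches_per_epoch, "ba")
    # Batches -> epochs is only exact when the count divides evenly.
    if d.value % batches_per_epoch:
        raise ValueError("cannot express a partial epoch as whole epochs")
    return Duration(d.value // batches_per_epoch, "ep")


def next_max_duration(max_duration, duration, reset_time: bool, batches_per_epoch: int) -> Duration:
    """Model how max_duration changes when fit() is called with `duration`."""
    new = parse_duration(duration)
    if reset_time:
        # reset_time=True: max_duration is simply set to `duration`.
        return new
    # reset_time=False: convert max_duration to the units of `duration`,
    # then increment it by `duration`.
    current = to_unit(parse_duration(max_duration), new.unit, batches_per_epoch)
    return Duration(current.value + new.value, new.unit)
```

With 100 batches per epoch, `next_max_duration("2ep", "1ba", reset_time=False, batches_per_epoch=100)` returns `Duration(value=201, unit='ba')`, which corresponds to the "2 epochs + 1 batch" step in the example above.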
- predict(dataloader, subset_num_batches=-1, *, return_outputs=True)
Output model prediction on the provided data.
There are two ways to access the prediction outputs.

1. With `return_outputs` set to True, the batch predictions will be collected into a list and returned.
2. Via a custom callback, which can be used with `return_outputs` set to False. This technique can be useful if collecting all the outputs from the dataloader would exceed available memory and you want to write outputs directly to files. For example:
```python
import os

import torch
from torch.utils.data import DataLoader

from composer import Trainer, Callback
from composer.core import State
from composer.loggers import Logger

class PredictionSaver(Callback):
    def __init__(self, folder: str):
        self.folder = folder
        os.makedirs(self.folder, exist_ok=True)

    def predict_batch_end(self, state: State, logger: Logger) -> None:
        name = f'batch_{int(state.predict_timestamp.batch)}.pt'
        filepath = os.path.join(self.folder, name)
        torch.save(state.outputs, filepath)

        # Also upload the files
        logger.upload_file(remote_file_name=name, file_path=filepath)

trainer = Trainer(
    ...,
    callbacks=PredictionSaver('./predict_outputs'),
)

trainer.predict(predict_dl, return_outputs=False)

print(sorted(os.listdir('./predict_outputs')))
# ['batch_1.pt', ...]
```
- Parameters
  - dataloader (DataLoader | DataSpec) – The `DataLoader` or `DataSpec` for the prediction data.
  - subset_num_batches (int, optional) – If specified, only perform model prediction on this many batches. This parameter has no effect if it is greater than `len(dataloader)`. If `-1`, then the entire loader will be iterated over. (default: `-1`)
  - return_outputs (bool, optional) – If True (the default), then prediction outputs will be (recursively) moved to CPU and accumulated into a list. Otherwise, prediction outputs are discarded after each batch.
- Returns
  - list – A list of batch outputs, if `return_outputs` is True. Otherwise, an empty list.
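The interaction between `subset_num_batches` and `return_outputs` described above can be modeled with a plain Python loop. This is a hypothetical sketch (`toy_predict` and its `forward` callable are illustrative, not part of Composer); it ignores devices, so the move to CPU that `predict()` performs is only noted in a comment.

```python
from itertools import islice
from typing import Any, Callable, Iterable, List


def toy_predict(
    dataloader: Iterable[Any],
    forward: Callable[[Any], Any],
    subset_num_batches: int = -1,
    return_outputs: bool = True,
) -> List[Any]:
    """Hypothetical model of the documented predict() semantics."""
    # -1 iterates over the whole loader; a value larger than the number of
    # batches has no extra effect because islice stops when the loader ends.
    if subset_num_batches == -1:
        batches = iter(dataloader)
    else:
        batches = islice(dataloader, subset_num_batches)
    outputs: List[Any] = []
    for batch in batches:
        out = forward(batch)
        if return_outputs:
            # predict() would also move the outputs to CPU before storing them.
            outputs.append(out)
        # Otherwise the output is dropped; a callback could persist it instead.
    return outputs
```

With `return_outputs=False` the function returns an empty list, matching the documented return value.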
- save_checkpoint(name='ep{epoch}-ba{batch}-rank{rank}', *, weights_only=False)
Checkpoint the training `State`.
- Parameters
  - name (str, optional) – See `composer.utils.save_checkpoint()`.
  - weights_only (bool, optional) – See `composer.utils.save_checkpoint()`.
- Returns
  - str or None – See `composer.utils.save_checkpoint()`.
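The default `name` is a template whose `{epoch}`, `{batch}`, and `{rank}` fields are filled in when the checkpoint is written, so each rank writes its own file. The sketch below assumes plain `str.format` substitution of just those three fields; Composer's own name formatting may support additional placeholders, and `format_checkpoint_name` is a hypothetical helper.

```python
DEFAULT_NAME = "ep{epoch}-ba{batch}-rank{rank}"


def format_checkpoint_name(template: str, epoch: int, batch: int, rank: int) -> str:
    """Fill the {epoch}, {batch} and {rank} fields of a checkpoint name template."""
    # str.format ignores unused keyword arguments, so templates that omit a
    # field (e.g. "ep{epoch}.pt") still work.
    return template.format(epoch=epoch, batch=batch, rank=rank)


# Two ranks saving at the same point in training produce distinct names.
names = [format_checkpoint_name(DEFAULT_NAME, epoch=1, batch=500, rank=r) for r in range(2)]
print(names)  # ['ep1-ba500-rank0', 'ep1-ba500-rank1']
```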
- save_checkpoint_to_save_folder()
Checkpoints the training `State` using a `CheckpointSaver`, if one exists.
- Raises
  - ValueError – If `_checkpoint_saver` does not exist.
- Returns
  - None
- property saved_checkpoints
Returns the list of saved checkpoints.
Note: For DeepSpeed, which saves files on every rank, only the files corresponding to the process's rank will be shown.