Trainer

class composer.Trainer(*, model, train_dataloader=None, train_dataloader_label='train', train_subset_num_batches=-1, max_duration=None, algorithms=None, algorithm_passes=None, optimizers=None, schedulers=None, scale_schedule_ratio=1.0, step_schedulers_every_batch=None, eval_dataloader=None, eval_interval=1, eval_subset_num_batches=-1, callbacks=None, loggers=None, run_name=None, progress_bar=True, log_to_console=False, console_stream='stderr', console_log_interval='1ba', log_traces=False, auto_log_hparams=False, load_path=None, load_object_store=None, load_weights_only=False, load_strict_model_weights=False, load_progress_bar=True, load_ignore_keys=None, load_exclude_algorithms=None, save_folder=None, save_filename='ep{epoch}-ba{batch}-rank{rank}.pt', save_latest_filename='latest-rank{rank}.pt', save_overwrite=False, save_interval='1ep', save_weights_only=False, save_num_checkpoints_to_keep=-1, autoresume=False, deepspeed_config=None, fsdp_config=None, device=None, precision=None, grad_accum=1, device_train_microbatch_size=None, seed=None, deterministic_mode=False, dist_timeout=1800.0, ddp_sync_strategy=None, profiler=None, python_log_level=None)

Train models with Composer algorithms.

The trainer supports models that are ComposerModel instances. The Trainer is highly customizable and can support a wide variety of workloads. See the training guide for more information.

Example

Train a model and save a checkpoint:

import os
from composer import Trainer

# Create a trainer
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="1ep",
    eval_dataloader=eval_dataloader,
    optimizers=optimizer,
    schedulers=scheduler,
    device="cpu",
    eval_interval="1ep",
    save_folder="checkpoints",
    save_filename="ep{epoch}.pt",
    save_interval="1ep",
    save_overwrite=True,
)

# Fit and run evaluation for 1 epoch.
# Save a checkpoint after 1 epoch as specified during trainer creation.
trainer.fit()

Load the checkpoint and resume training:

# Get the saved checkpoint filepath
checkpoint_path = trainer.saved_checkpoints.pop()

# Create a new trainer with the `load_path` argument set to the checkpoint path.
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="2ep",
    eval_dataloader=eval_dataloader,
    optimizers=optimizer,
    schedulers=scheduler,
    device="cpu",
    eval_interval="1ep",
    load_path=checkpoint_path,
)

# Continue training and running evaluation where the previous trainer left off
# until the new max_duration is reached.
# In this case it will be one additional epoch to reach 2 epochs total.
trainer.fit()
Parameters
  • model (ComposerModel) –

    The model to train. Can be user-defined or one of the models included with Composer.

    See also

    composer.models for models built into Composer.

  • train_dataloader (Iterable | DataSpec | dict, optional) –

    The dataloader, DataSpec, or dict of DataSpec kwargs for the training data. In order to specify custom preprocessing steps on each data batch, specify a DataSpec instead of a dataloader. It is recommended that the dataloader, whether specified directly or as part of a DataSpec, should be a torch.utils.data.DataLoader.

    Note

    The train_dataloader should yield per-rank batches. Each per-rank batch will then be further divided based on the device_train_microbatch_size parameter. For example, if the desired optimization batch size is 2048 and training is happening across 8 GPUs, then each train_dataloader should yield a batch of size 2048 / 8 = 256. If device_train_microbatch_size = 128, then the per-rank batch will be divided into 256 / 128 = 2 microbatches of size 128.

    If train_dataloader is not specified when constructing the trainer, it must be specified when invoking Trainer.fit().
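The batch arithmetic in the note above can be checked with a small sketch. The helper name per_rank_batch_plan is hypothetical and not part of Composer; it only reproduces the division described in the note.

```python
def per_rank_batch_plan(global_batch_size, world_size, device_train_microbatch_size):
    """Return (per_rank_batch_size, num_microbatches) for one optimization step."""
    if global_batch_size % world_size != 0:
        raise ValueError("global batch size must divide evenly across ranks")
    per_rank = global_batch_size // world_size
    # Ceiling division: a trailing microbatch may be smaller than the others.
    num_microbatches = -(-per_rank // device_train_microbatch_size)
    return per_rank, num_microbatches


# The numbers from the note: 2048 samples per step, 8 GPUs, microbatches of 128.
per_rank, num_microbatches = per_rank_batch_plan(2048, 8, 128)
print(per_rank, num_microbatches)
```

Each dataloader should therefore be constructed with the per-rank batch size, not the global one.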

  • train_dataloader_label (str, optional) –

    The label for the train dataloader. (default: 'train')

    This label is used to index the training metrics in State.train_metrics.

    This parameter has no effect if train_dataloader is not specified.

  • train_subset_num_batches (int, optional) –

    If specified, finish every epoch early after training on this many batches. This parameter has no effect if it is greater than len(train_dataloader). If -1, then the entire dataloader will be iterated over. (default: -1)

    When using the profiler, it can be helpful to set this parameter to the length of the profile schedule. This setting will end each epoch early to avoid additional training that will not be profiled.

    This parameter is ignored if train_dataloader is not specified.

  • max_duration (Time | str | int, optional) –

    The maximum duration to train. Can be an integer, which will be interpreted to be epochs, a str (e.g. 1ep, or 10ba), or a Time object.

    If max_duration is not specified when constructing the trainer, duration must be specified when invoking Trainer.fit().
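To illustrate how the accepted forms of max_duration are read, here is a simplified stand-in written with the standard library only. It is an assumption-laden sketch, not Composer's parser: the function parse_duration is hypothetical, and it handles only the two units shown above (ep and ba). In real code the Time class does this work.

```python
import re

# Only the two units mentioned in the text above are handled in this sketch.
_UNITS = {"ep": "epoch", "ba": "batch"}


def parse_duration(value):
    """Return (amount, unit) for an int or a string such as '1ep' or '10ba'."""
    if isinstance(value, int):
        # A bare integer is interpreted as a number of epochs.
        return value, "epoch"
    match = re.fullmatch(r"\s*(\d+)\s*(ep|ba)\s*", value)
    if match is None:
        raise ValueError(f"unsupported duration: {value!r}")
    return int(match.group(1)), _UNITS[match.group(2)]


print(parse_duration(1), parse_duration("1ep"), parse_duration("10ba"))
```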

  • algorithms (Algorithm | Sequence[Algorithm], optional) –

    The algorithms to use during training. If None, then no algorithms will be used. (default: None)

    See also

    composer.algorithms for the different algorithms built into Composer.

  • algorithm_passes (AlgorithmPass | Tuple[AlgorithmPass, int] | Sequence[AlgorithmPass | Tuple[AlgorithmPass, int]], optional) –

    Optional list of passes to change the order in which algorithms are applied. These passes are merged with the default passes specified in Engine. If None, then no additional passes will be used. (default: None)

    See also

    composer.core.Engine for more information.

  • optimizers (Optimizer, optional) –

    The optimizer. If None, will be set to DecoupledSGDW(model.parameters(), lr=0.1). (default: None)

    See also

    composer.optim for the different optimizers built into Composer.

  • schedulers (PyTorchScheduler | ComposerScheduler | Sequence[PyTorchScheduler | ComposerScheduler], optional) –

    The learning rate schedulers. If [] or None, the learning rate will be constant. (default: None).

    See also

    composer.optim.scheduler for the different schedulers built into Composer.

  • scale_schedule_ratio (float, optional) –

    Ratio by which to scale the training duration and learning rate schedules. (default: 1.0)

    E.g., 0.5 makes the schedule take half as many epochs and 2.0 makes it take twice as many epochs. 1.0 means no change.

    This parameter has no effect if schedulers is not specified.

    Note

    Training for less time, while rescaling the learning rate schedule, is a strong baseline approach to speeding up training. E.g., training for half duration often yields minor accuracy degradation, provided that the learning rate schedule is also rescaled to take half as long.

    To see the difference, consider training for half as long using a cosine annealing learning rate schedule. If the schedule is not rescaled, training ends while the learning rate is still ~0.5 of the initial LR. If the schedule is rescaled with scale_schedule_ratio, the LR schedule would finish the entire cosine curve, ending with a learning rate near zero.
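The cosine example in the note can be worked through numerically. The sketch below assumes a plain cosine annealing curve with no warmup that decays to zero; cosine_lr_factor is a hypothetical helper, not a Composer scheduler.

```python
import math


def cosine_lr_factor(step, schedule_steps):
    """Multiplier on the initial LR for cosine annealing to zero, without warmup."""
    progress = min(step / schedule_steps, 1.0)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


full_steps = 1000                        # length of the original schedule
ratio = 0.5                              # train for half as long
train_steps = int(full_steps * ratio)

# Schedule left unscaled: training stops halfway along the cosine curve.
unscaled_final = cosine_lr_factor(train_steps, full_steps)
# Schedule rescaled by the same ratio: the curve completes when training ends.
scaled_final = cosine_lr_factor(train_steps, int(full_steps * ratio))

print(round(unscaled_final, 3), round(scaled_final, 3))
```

Without rescaling, the final multiplier is about 0.5 of the initial learning rate; with rescaling it reaches zero at the last step.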

  • step_schedulers_every_batch (bool, optional) – By default, native PyTorch schedulers are updated every epoch, while Composer schedulers are updated every step. Setting this to True forces schedulers to be stepped every batch, while False means schedulers are stepped every epoch. None indicates the default behavior. (default: None)

  • eval_dataloader (DataLoader | DataSpec | Evaluator | Sequence[Evaluator], optional) –

    The DataLoader, DataSpec, Evaluator, or sequence of evaluators for the evaluation data.

    To evaluate one or more specific metrics across one or more datasets, pass in an Evaluator. If a DataSpec or DataLoader is passed in, then all metrics returned by model.get_metrics() will be used during evaluation. None results in no evaluation. (default: None)

  • eval_interval (int | str | Time | (State, Event) -> bool, optional) –

    Specifies how frequently to run evaluation. An integer, which will be interpreted to be epochs, a str (e.g. 1ep, or 10ba), a Time object, or a callable. Defaults to 1 (evaluate every epoch).

    If an integer (in epochs), Time string, or Time instance, the evaluator will be run with this frequency. Time strings or Time instances must have units of TimeUnit.BATCH or TimeUnit.EPOCH.

    Set to 0 to disable evaluation.

    If a callable, it should take two arguments (State, Event) and return a bool representing whether the evaluator should be invoked. The event will be either Event.BATCH_END or Event.EPOCH_END.

    This eval_interval will apply to any Evaluator in eval_dataloader that does not specify an eval_interval or if a dataloader is passed in directly. This parameter has no effect if eval_dataloader is not specified.

    When specifying a time string or integer for the eval_interval, the evaluator(s) are also run at Event.FIT_END if the interval doesn't evenly divide the training duration.
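A callable eval_interval, as described above, might look like the following sketch. FakeEvent and the SimpleNamespace state are stand-ins so the example runs without Composer installed; make_eval_interval and its parameters are hypothetical names. It assumes that the event exposes an enum-style name and that state.timestamp.batch converts to an int.

```python
from enum import Enum
from types import SimpleNamespace


class FakeEvent(Enum):
    """Stand-in for Composer's Event, limited to the two events named above."""
    BATCH_END = "batch_end"
    EPOCH_END = "epoch_end"


def make_eval_interval(every_n_batches, first_batch=0):
    """Build a (state, event) -> bool callable that evaluates every N batches."""
    def should_eval(state, event):
        # Act only at batch boundaries; ignore the EPOCH_END invocation.
        if event.name != "BATCH_END":
            return False
        batch = int(state.timestamp.batch)
        return batch >= first_batch and batch % every_n_batches == 0
    return should_eval


# Evaluate every 500 batches, but only once 1000 batches have been trained.
should_eval = make_eval_interval(every_n_batches=500, first_batch=1000)
state = SimpleNamespace(timestamp=SimpleNamespace(batch=1500))
print(should_eval(state, FakeEvent.BATCH_END))
```

With Composer available, the returned function would be passed as eval_interval=should_eval when constructing the Trainer.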

  • eval_subset_num_batches (int, optional) –

    If specified, evaluate on this many batches. Defaults to -1, which means to iterate over the entire dataloader.

    This parameter has no effect if eval_dataloader is not specified, it is greater than len(eval_dataloader), or eval_dataloader is an Evaluator and subset_num_batches was specified as part of the Evaluator.

  • callbacks (Callback | Sequence[Callback], optional) –

    The callbacks to run during training. If None, then no callbacks will be run. (default: None).

    See also

    composer.callbacks for the different callbacks built into Composer.

  • loggers (LoggerDestination | Sequence[LoggerDestination], optional) –

    The destinations to log training information to.

    See also

    composer.loggers for the different loggers built into Composer.

  • run_name (str, optional) – A name for this training run. If not specified, the timestamp will be combined with a coolname, e.g. 1654298855-electric-zebra.

  • progress_bar (bool) – Whether to show a progress bar. (default: True)

  • log_to_console (bool) – Whether to print logging statements to the console. (default: False)

  • console_stream (TextIO | str, optional) – The stream to write to. If a string, it can either be 'stdout' or 'stderr'. (default: sys.stderr)

  • console_log_interval (int | str | Time, optional) –

    Specifies how frequently to log metrics to console. An integer, which will be interpreted to be epochs, a str (e.g. 1ep, or 10ba), a Time object, or a callable. Defaults to 1ba (log metrics every batch).

    If an integer (in epochs), Time string, or Time instance, the metrics will be logged with this frequency. Time strings or Time instances must have units of TimeUnit.BATCH or TimeUnit.EPOCH.

    Set to 0 to disable metrics logging to console.

  • log_traces (bool) – Whether to log traces or not. (default: False)

  • auto_log_hparams (bool) – Whether to automatically extract hyperparameters. (default: False)

  • load_path (str, optional) –

    The path format string to an existing checkpoint file.

    It can be a path to a file on the local disk, a URL, or if load_object_store is set, the object name for a checkpoint in a cloud bucket. If a URI is specified, load_object_store does not need to be set.

    When using DeepSpeed ZeRO, checkpoints are sharded by rank. Instead of hard-coding the rank in the path, use the following format variables:

    Variable        Description
    {rank}          The global rank, as returned by get_global_rank().
    {local_rank}    The local rank of the process, as returned by get_local_rank().
    {node_rank}     The node rank, as returned by get_node_rank().

    For example, suppose that checkpoints are stored in the following structure:

    my_model/ep1-rank0.tar
    my_model/ep1-rank1.tar
    my_model/ep1-rank2.tar
    ...
    

    Then, load_path should be set to my_model/ep1-rank{rank}.tar, and all ranks will load the correct state.

    If None then no checkpoint will be loaded. (default: None)
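The per-rank substitution described for load_path can be pictured with ordinary string formatting. This is only a sketch: Composer performs the substitution itself, and the assumption here is that it behaves like str.format with the three variables listed above. The helper resolve_load_path and the two-node layout are hypothetical.

```python
def resolve_load_path(load_path, rank, local_rank, node_rank):
    """Fill the rank format variables in a load_path format string."""
    return load_path.format(rank=rank, local_rank=local_rank, node_rank=node_rank)


# Hypothetical layout: two nodes with two processes each.
paths = [
    resolve_load_path(
        "my_model/ep1-rank{rank}.tar",
        rank=rank,
        local_rank=rank % 2,
        node_rank=rank // 2,
    )
    for rank in range(4)
]
print(paths)
```

Every rank receives the same format string and resolves it to its own shard.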

  • load_object_store (Union[ObjectStore, LoggerDestination], optional) –

    If the load_path is in an object store (e.g. AWS S3 or Google Cloud Storage), an instance of ObjectStore or LoggerDestination which will be used to retrieve the checkpoint. Otherwise, if the checkpoint is a local filepath, set to None. It can also be None if the load_path is an S3 URI, because the appropriate object store will be constructed automatically in that case. Ignored if load_path is None. (default: None)

    Example:

    from composer import Trainer
    from composer.utils import LibcloudObjectStore
    
    # Create the object store provider with the specified credentials
    creds = {"key": "object_store_key",
             "secret": "object_store_secret"}
    store = LibcloudObjectStore(provider="s3",
                                container="my_container",
                                provider_kwargs=creds)
    
    checkpoint_path = "./path_to_the_checkpoint_in_object_store"
    
    # Create a trainer which will load a checkpoint from the specified object store
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration="10ep",
        eval_dataloader=eval_dataloader,
        optimizers=optimizer,
        schedulers=scheduler,
        device="cpu",
        eval_interval="1ep",
        load_path=checkpoint_path,
        load_object_store=store,
    )
    

  • load_weights_only (bool, optional) – Whether or not to only restore the weights from the checkpoint without restoring the associated state. Ignored if load_path is None. (default: False)

  • load_strict_model_weights (bool, optional) – Whether to require that the set of weights in the checkpoint and the model exactly match. Ignored if load_path is None. (default: False)

  • load_progress_bar (bool, optional) – Display the progress bar for downloading the checkpoint. Ignored if load_path is either None or a local file path. (default: True)

  • load_ignore_keys (List[str] | (Dict) -> None, optional) –

    A list of paths into the state_dict of the checkpoint which, when provided, will be removed from the state_dict before the checkpoint is loaded. Each path is a string made of the keys used to index into state_dict, joined together with / as a separator (as PyTorch uses . in parameter names). If a prefix is provided, all children are also ignored (see Example 2). See composer.core.state for the structure of state_dict.

    Example 1: load_ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"] would ignore layer 1 weights and bias.

    Example 2: load_ignore_keys = ["state/model/*"] would ignore the entire model, which would have the same effect as the previous example if there was only 1 layer.

    Example 3: load_ignore_keys = ["state/model/layer*.weights"] would ignore all weights in the model.

    Example 4: load_ignore_keys = ["state/rank_zero_seed", "rng"] would reset all randomness when loading the checkpoint.

    If a callable, it should take one argument which is the state_dict. The callable is free to arbitrarily modify the state_dict before it is loaded.

    (default: None)
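The callable form of load_ignore_keys can be sketched as a function that edits the state_dict in place. The example below mirrors Example 4; the nested layout of the toy checkpoint dictionary is inferred from the paths in that example and is an assumption, and reset_randomness is a hypothetical name.

```python
def reset_randomness(state_dict):
    """Drop the seed and RNG entries in place, mirroring Example 4."""
    state_dict.get("state", {}).pop("rank_zero_seed", None)
    state_dict.pop("rng", None)


# Toy checkpoint contents, used only to exercise the function.
checkpoint = {
    "state": {"model": {"layer1.weights": [0.1]}, "rank_zero_seed": 17},
    "rng": ["per-rank rng state"],
}
reset_randomness(checkpoint)
print(sorted(checkpoint), sorted(checkpoint["state"]))
```

Such a function would be passed as load_ignore_keys=reset_randomness; it returns None, matching the (Dict) -> None signature above.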

  • load_exclude_algorithms (List[str], optional) –

    A list of algorithm names to exclude from loading. By default, algorithms with required_on_load=True which were enabled when training the loaded checkpoint are automatically applied unless they conflict with a user specified algorithm. These algorithms often change the model, and not applying them could result in certain layers not having weights loaded.

    Example 1: load_exclude_algorithms = ["BlurPool"] would exclude BlurPool from loading.

    Example 2: load_exclude_algorithms = ["FusedLayerNorm", "Alibi"] would exclude FusedLayerNorm and Alibi from loading.

    (default: None)

  • save_folder (str, optional) –

    Format string for the folder where checkpoints are saved. If None, checkpoints will not be saved. Can also be a URI, for S3 paths only. In the case of an S3 URI, the appropriate RemoteUploaderDownloader object will be created automatically. (default: None)

    See also

    CheckpointSaver

    Note

    For fine-grained control on checkpoint saving (e.g. to save different types of checkpoints at different intervals), leave this parameter as None, and instead pass instance(s) of CheckpointSaver directly as callbacks.

  • save_filename (str, optional) –

    A format string describing how to name checkpoints. This parameter has no effect if save_folder is None. (default: "ep{epoch}-ba{batch}-rank{rank}.pt")

    See also

    CheckpointSaver

  • save_latest_filename (str, optional) –

    A format string for the name of a symlink (relative to save_folder) that points to the last saved checkpoint. This parameter has no effect if save_folder is None. To disable symlinking, set this to None. (default: "latest-rank{rank}.pt")

    See also

    CheckpointSaver

  • save_overwrite (bool, optional) –

    Whether existing checkpoints should be overwritten. This parameter has no effect if save_folder is None. (default: False)

    See also

    CheckpointSaver

  • save_interval (Time | str | int | (State, Event) -> bool) –

    A Time, time-string, integer (in epochs), or a function that takes (state, event) and returns a boolean indicating whether a checkpoint should be saved. This parameter has no effect if save_folder is None. (default: '1ep')

    See also

    CheckpointSaver

  • save_weights_only (bool, optional) –

    Whether to save only the model weights instead of the entire training state. This parameter has no effect if save_folder is None. (default: False)

    See also

    CheckpointSaver

  • save_num_checkpoints_to_keep (int, optional) –

    The number of checkpoints to keep locally. The oldest checkpoints are removed first. Set to -1 to keep all checkpoints locally. (default: -1)

    Checkpoints will be removed after they have been uploaded. For example, when this callback is used in conjunction with the RemoteUploaderDownloader, set this parameter to 0 to immediately delete checkpoints from the local disk after they have been uploaded to the object store.

    This parameter only controls how many checkpoints are kept locally; checkpoints are not deleted from remote file systems.
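The retention rule for save_num_checkpoints_to_keep can be sketched as a pure function over checkpoint names ordered oldest to newest. prune_local_checkpoints is a hypothetical helper that models only the bookkeeping, not file deletion or uploads.

```python
def prune_local_checkpoints(saved, num_to_keep):
    """Return (kept, removed) for checkpoint names ordered oldest to newest."""
    if num_to_keep < 0:
        # -1 keeps every local checkpoint.
        return list(saved), []
    split = max(len(saved) - num_to_keep, 0)
    # The oldest entries are removed first; the newest num_to_keep remain.
    return list(saved[split:]), list(saved[:split])


saved = ["ep1.pt", "ep2.pt", "ep3.pt"]
print(prune_local_checkpoints(saved, 2))
print(prune_local_checkpoints(saved, 0))
print(prune_local_checkpoints(saved, -1))
```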

  • autoresume (bool, optional) –

    Whether or not to enable autoresume, which allows for stopping and resuming training. This allows use of spot instances, as the training run is now fault tolerant. This parameter requires save_folder and run_name to be specified and save_overwrite to be False. (default: False)

    When enabled, the save_folder is checked for checkpoints of the format "{save_folder}/{save_latest_filename}", which are loaded to continue training. If no local checkpoints are found, each logger is checked for potential remote checkpoints named "{save_folder}/{save_latest_filename}". Finally, if no logged checkpoints are found, load_path is used to load a checkpoint if specified. This should only occur at the start of a run using autoresume.

    For example, to run a fine-tuning run on a spot instance, load_path would be set to the original weights and an object store logger would be added. In the original run, load_path would be used to get the starting checkpoint. For any future restarts, such as due to the spot instance being killed, the loggers would be queried for the latest checkpoint, which would be downloaded from the object store logger and used to resume training.
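The lookup order described for autoresume can be summarized as a small decision function. This is a sketch of the order only, under the assumption that the three sources are checked strictly in sequence; pick_resume_source and its arguments are hypothetical and do not touch the filesystem or any logger.

```python
def pick_resume_source(local_latest_exists, remote_latest_by_logger, load_path):
    """Return (source, detail) following the order: local, loggers, load_path."""
    if local_latest_exists:
        return ("local", None)
    # Loggers are consulted in the order they were given.
    for logger_name, has_latest in remote_latest_by_logger.items():
        if has_latest:
            return ("logger", logger_name)
    if load_path is not None:
        return ("load_path", load_path)
    return ("fresh", None)


# First start of a fine-tuning run: nothing saved yet, so load_path is used.
print(pick_resume_source(False, {"object_store": False}, "base_weights.pt"))
# Restart on a new spot instance: no local files, but the logger has a checkpoint.
print(pick_resume_source(False, {"object_store": True}, "base_weights.pt"))
```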

  • deepspeed_config (Dict[str, Any], optional) –

    Configuration for DeepSpeed, formatted as a JSON according to DeepSpeed's documentation. (default: None)

    To use DeepSpeed with default values, set to the empty dictionary {}. To disable DeepSpeed (the default), set to None.

  • fsdp_config (Dict[str, Any], optional) – Configuration for FSDP. See FSDP Documentation for more details. To use FSDP with default values, set to the empty dictionary {}. To disable FSDP, set to None. (default: None)

  • device (Device | str, optional) –

    The device to use for training, which can be 'cpu', 'gpu', 'tpu', or 'mps'. (default: None)

    The default behavior sets the device to 'gpu' if CUDA is available, and otherwise 'cpu'.

  • precision (Precision | str, optional) – Numerical precision to use for training. One of fp32, amp_bf16 or amp_fp16 (recommended). (default: Precision.FP32 if training on CPU; Precision.AMP_FP16 if training on GPU)

  • grad_accum (Union[int, str], optional) –

    The number of microbatches to split a per-device batch into. Gradients are summed over the microbatches per device. If set to auto, grad_accum is increased dynamically when a microbatch is too large for the GPU. (default: 1)

    Note

    This is implemented by taking the batch yielded by the train_dataloader and splitting it into grad_accum sections. Each section is of size batch_size // grad_accum. If the batch size of the dataloader is not divisible by grad_accum, then the last section will be of size batch_size mod grad_accum.

    Deprecated since version 0.12: Please use device_train_microbatch_size.

  • device_train_microbatch_size (Union[int, str], optional) –

    The number of samples to process on each device per microbatch during training. Gradients are summed over the microbatches per device. If set to auto, dynamically decreases device_train_microbatch_size if microbatch is too large for GPU. (default: None)

    Note

    This is implemented by taking the batch yielded by the train_dataloader and splitting it into sections of size device_train_microbatch_size. If the batch size of the dataloader is not divisible by device_train_microbatch_size, the last section will be potentially smaller.
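The splitting rule in the note above can be sketched on a plain list. Real batches are tensors or nested structures, so this is only an illustration of the chunking; split_into_microbatches is a hypothetical helper.

```python
def split_into_microbatches(batch, device_train_microbatch_size):
    """Split a per-rank batch into consecutive chunks of at most the given size."""
    size = device_train_microbatch_size
    return [batch[i:i + size] for i in range(0, len(batch), size)]


batch = list(range(10))                       # a per-rank batch of 10 samples
microbatches = split_into_microbatches(batch, 4)
print([len(m) for m in microbatches])         # the last chunk is smaller
```

Gradients from each chunk are accumulated before the optimizer step, so the result matches a single pass over the full per-rank batch up to numerical differences.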

  • seed (int, optional) –

    The seed used in randomization. If None, then a random seed will be created. (default: None)

    Note

    In order to get reproducible results, call the seed_all() function at the start of your script with the seed passed to the trainer. This will ensure any initialization done before the trainer init (e.g. model weight initialization) also uses the provided seed.

    See also

    composer.utils.reproducibility for more details on reproducibility.

  • deterministic_mode (bool, optional) –

    Run the model deterministically. (default: False)

    Note

    This is an experimental feature. Performance degradation is expected. Certain Torch modules may not have deterministic implementations, which will result in a crash.

    Note

    In order to get reproducible results, call the configure_deterministic_mode() function at the start of your script. This will ensure any initialization done before the trainer init also runs deterministically.

    See also

    composer.utils.reproducibility for more details on reproducibility.

  • dist_timeout (float, optional) – Timeout, in seconds, for initializing the distributed process group. (default: 1800.0)

  • ddp_sync_strategy (str | DDPSyncStrategy, optional) – The strategy to use for synchronizing gradients. Leave unset to let the trainer auto-configure this. See DDPSyncStrategy for more details.

  • profiler (Profiler, optional) –

    The profiler, if profiling should be enabled. (default: None)

    See also

    See the Profiling Guide for additional information.

  • python_log_level (str, optional) –

    The Python log level to use for log statements in the composer module. (default: None). If it is None, python logging will not be configured (i.e. logging.basicConfig won't be called).

    See also

    The logging module in Python.

state

The State object used to store training state.

Type

State

evaluators

The Evaluator objects to use for validation during training.

Type

List[Evaluator]

logger

The Logger used for logging.

Type

Logger

engine

The Engine used for running callbacks and algorithms.

Type

Engine

close()

Shutdown the trainer.

See also

Engine.close() for additional information.

eval(eval_dataloader=None, subset_num_batches=-1)

Run evaluation loop.

Results are stored in trainer.state.eval_metrics. The eval_dataloader can be provided either to the eval() method or to the trainer init().

Examples:

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="2ep",
    device="cpu",
)

trainer.fit()

# run eval
trainer.eval(
    eval_dataloader=eval_dataloader,
)

Or, if the eval_dataloader is provided during init:

trainer = Trainer(
    model=model,
    eval_dataloader=eval_dataloader,
    train_dataloader=train_dataloader,
    max_duration="2ep",
    device="cpu",
)

trainer.fit()

# eval_dataloader already provided:
trainer.eval()

For multiple metrics or dataloaders, use Evaluator to provide identifier names. For example, to run the GLUE task:

from composer.core import Evaluator
from composer.models.nlp_metrics import BinaryF1Score

glue_mrpc_task = Evaluator(
    label='glue_mrpc',
    dataloader=mrpc_dataloader,
    metric_names=['BinaryF1Score', 'Accuracy']
)

glue_mnli_task = Evaluator(
    label='glue_mnli',
    dataloader=mnli_dataloader,
    metric_names=['Accuracy']
)

trainer = Trainer(
    ...,
    eval_dataloader=[glue_mrpc_task, glue_mnli_task],
    ...
)

The metrics used are defined in your model's get_metrics() method. For more information, see 📊 Evaluation.

Note

This eval API was recently changed to better match the trainer fit API. Please migrate your code to the new design. For backwards compatibility, the old API can still be invoked by calling _eval_loop(); however, this is not recommended, as it may be removed in the future.

Parameters
  • eval_dataloader (DataLoader | DataSpec | Evaluator | Sequence[Evaluator], optional) – Dataloaders for evaluation. If not provided, defaults to using the eval_dataloader provided to the trainer init().

  • subset_num_batches (int, optional) – Evaluate on this many batches. Defaults to -1 (the entire dataloader). Can also be provided in the trainer init() as eval_subset_num_batches.

export_for_inference(save_format, save_path, save_object_store=None, sample_input=None, transforms=None)

Export a model for inference.

Parameters
  • save_format (Union[str, ExportFormat]) – Format to export to. Either "torchscript" or "onnx".

  • save_path (str) – The path for storing the exported model. It can be a path to a file on the local disk, a URL, or if save_object_store is set, the object name in a cloud bucket. For example, my_run/exported_model.

  • save_object_store (ObjectStore, optional) – If the save_path is an object name in a cloud bucket (e.g. AWS S3 or Google Cloud Storage), an instance of ObjectStore which will be used to store the exported model. If this is set to None, will save to save_path using the trainer's logger. (default: None)

  • sample_input (Any, optional) – Example model inputs used for tracing. This is needed for "onnx" export. The sample_input need not match the batch size you intend to use for inference. However, the model should accept the sample_input as is. (default: None)

  • transforms (Sequence[Transform], optional) – Transformations (usually optimizations) that should be applied to the model. Each Transform should be a callable that takes a model and returns a modified model.

Returns

None

fit(*, train_dataloader=None, train_dataloader_label='train', train_subset_num_batches=None, duration=None, reset_time=False, schedulers=None, scale_schedule_ratio=1.0, step_schedulers_every_batch=None, eval_dataloader=None, eval_subset_num_batches=-1, eval_interval=1, grad_accum=None, device_train_microbatch_size=None, precision=None)

Train the model.

The Composer Trainer supports multiple calls to fit(). Any arguments specified during the call to fit() will override the values specified when constructing the Trainer. All arguments are optional, with the following exceptions:

  • The train_dataloader must be specified here if not provided when constructing the Trainer.

  • The duration must be specified here if not provided when constructing the Trainer, or if this is a subsequent call to fit().

For example, the following are equivalent:

# The `train_dataloader` and `duration` can be specified
# when constructing the Trainer
trainer_1 = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="1ep",
)
trainer_1.fit()

# Or, these arguments can be specified on `fit()`
trainer_2 = Trainer(model=model)  # `model` is keyword-only
trainer_2.fit(
    train_dataloader=train_dataloader,
    duration="1ep"
)

When invoking fit() for a subsequent time, either reset_time or duration must be specified. Otherwise, it is ambiguous for how long to train.

  • If reset_time is True, then fit() will train for the same amount of time as the previous call (or for duration if that parameter is also specified). The State.timestamp will be reset, causing ComposerScheduler and Algorithm instances to start from the beginning, as if it is a new training run. Model gradients, optimizer states, and native PyTorch schedulers will not be reset.

  • If reset_time is False, then fit() will train for the amount of time specified by duration. The State.max_duration will be incremented by duration.

For example:

# Construct the trainer
trainer = Trainer(model=model, train_dataloader=train_dataloader, max_duration="1ep")

# Train for 1 epoch
trainer.fit()
assert trainer.state.timestamp.epoch == "1ep"

# Reset the time to 0, then train for 1 epoch
trainer.fit(reset_time=True)
assert trainer.state.timestamp.epoch == "1ep"

# Train for another epoch (2 epochs total)
trainer.fit(duration="1ep")
assert trainer.state.timestamp.epoch == "2ep"

# Train for another batch (2 epochs + 1 batch total)
# It's OK to switch time units!
trainer.fit(duration="1ba")
assert trainer.state.timestamp.epoch == "2ep"
assert trainer.state.timestamp.batch_in_epoch == "1ba"

# Reset the time, then train for 3 epochs
trainer.fit(reset_time=True, duration="3ep")
assert trainer.state.timestamp.epoch == "3ep"
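The bookkeeping behind reset_time and duration in the walkthrough above can be modeled with a small function. It is a simplified sketch that counts whole epochs only and ignores unit conversion; plan_fit is a hypothetical helper, not part of the Trainer.

```python
def plan_fit(elapsed, max_duration, duration=None, reset_time=False):
    """Return (start, new_max_duration), both in epochs, for one call to fit()."""
    if reset_time:
        # Time restarts at zero; train for `duration` if given, else the previous max.
        return 0, duration if duration is not None else max_duration
    if duration is None:
        if elapsed >= max_duration:
            raise ValueError("a subsequent fit() needs duration or reset_time")
        return elapsed, max_duration
    # Without a reset, the new duration is added on top of the old maximum.
    return elapsed, max_duration + duration


print(plan_fit(0, 1))                                # first fit(): epochs 0 to 1
print(plan_fit(1, 1, reset_time=True))               # restart and repeat 1 epoch
print(plan_fit(1, 1, duration=1))                    # continue to 2 epochs total
print(plan_fit(2, 2, duration=3, reset_time=True))   # restart and train 3 epochs
```

Calling fit() again with neither argument is the ambiguous case, which the sketch reports as an error.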
Parameters
  • train_dataloader (Iterable | DataSpec | Dict[str, Any], optional) – See Trainer.

  • train_dataloader_label (str, optional) – See Trainer.

  • train_subset_num_batches (int, optional) – See Trainer.

  • reset_time (bool) –

    Whether to reset the State.timestamp to zero values. Defaults to False.

    If True, the timestamp will be zeroed out, causing ComposerScheduler and Algorithm instances to start from the beginning, as if it is a new training run. The model will be trained for duration, if specified, or for State.max_duration, which would have been provided when constructing the Trainer or by a previous call to fit().

    Note

    Model gradients, optimizer states, and native PyTorch schedulers will not be reset.

    If False (the default), training time will be incremented from where the previous call to fit() finished (or from zero, if a new training run). The max_duration will be incremented by the duration parameter.

  • duration (Time[int] | str | int, optional) –

    The duration to train. Can be an integer, which will be interpreted to be epochs, a str (e.g. 1ep, or 10ba), or a Time object.

    If reset_time is False (the default), then State.max_duration will be converted into the same units as this parameter (if necessary), and then the max duration incremented by the value of this parameter.

    If reset_time is True, then State.max_duration will be set to this parameter.

  • optimizers (Optimizer | Sequence[Optimizer], optional) – See Trainer.

  • schedulers (PyTorchScheduler | ComposerScheduler | Sequence[PyTorchScheduler | ComposerScheduler], optional) – See Trainer.

  • scale_schedule_ratio (float, optional) – See Trainer.

  • step_schedulers_every_batch (bool, optional) – See Trainer.

  • eval_dataloader (Iterable | DataSpec | Evaluator | Sequence[Evaluator], optional) – See Trainer.

  • eval_subset_num_batches (int, optional) – See Trainer.

  • eval_interval (int | str | Time | (State, Event) -> bool, optional) – See Trainer.

  • grad_accum (int | str, optional) – See Trainer.

  • device_train_microbatch_size (int | str, optional) – See Trainer.

  • precision (Precision | str, optional) – See Trainer.

predict(dataloader, subset_num_batches=-1, *, return_outputs=True)

Output model prediction on the provided data.

There are two ways to access the prediction outputs.

  1. With return_outputs set to True, the batch predictions will be collected into a list and returned.

  2. Via a custom callback, which can be used with return_outputs set to False.

    This technique can be useful if collecting all the outputs from the dataloader would exceed available memory, and you want to write outputs directly to files. For example:

    import os
    import torch
    
    from torch.utils.data import DataLoader
    
    from composer import Trainer, Callback
    from composer.core import State
    from composer.loggers import Logger
    
    class PredictionSaver(Callback):
        def __init__(self, folder: str):
            self.folder = folder
            os.makedirs(self.folder, exist_ok=True)
    
        def predict_batch_end(self, state: State, logger: Logger) -> None:
            name = f'batch_{int(state.predict_timestamp.batch)}.pt'
            filepath = os.path.join(self.folder, name)
            torch.save(state.outputs, filepath)
    
            # Also upload the files
            logger.upload_file(remote_file_name=name, file_path=filepath)
    
    trainer = Trainer(
        ...,
        callbacks=PredictionSaver('./predict_outputs'),
    )
    
    trainer.predict(predict_dl, return_outputs=False)
    
    print(sorted(os.listdir('./predict_outputs')))
    
    ['batch_1.pt', ...]
    
Parameters
  • dataloader (DataLoader | DataSpec) – The DataLoader or DataSpec for the prediction data.

  • subset_num_batches (int, optional) – If specified, only perform model prediction on this many batches. This parameter has no effect if it is greater than len(dataloader). If -1, then the entire loader will be iterated over. (default: -1)

  • return_outputs (bool, optional) – If True (the default), then prediction outputs will be (recursively) moved to cpu and accumulated into a list. Otherwise, prediction outputs are discarded after each batch.

Returns

List – A list of batch outputs, if return_outputs is True. Otherwise, an empty list.

save_checkpoint(name='ep{epoch}-ba{batch}-rank{rank}', *, weights_only=False)

Checkpoint the training State.

Parameters
Returns

str or None – See save_checkpoint().

property saved_checkpoints

Returns the list of saved checkpoints.

Note

For DeepSpeed, which saves files on every rank, only the files corresponding to the process's rank will be shown.