class composer.algorithms.SeqLengthWarmup(duration=0.3, min_seq_length=8, max_seq_length=1024, step_size=8, truncate=True, preserve_end_of_sequence=False)

Progressively increases the sequence length during training.

Changes the sequence length of all tensors in the input batch. The sequence length increases from min_seq_length to max_seq_length in steps of step_size during the first duration fraction of training.

The sequence length is then kept at max_seq_length for the rest of training.
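As a rough illustration of this schedule (a sketch only; Composer's exact stepping logic may differ), the sequence length at a given fraction of training can be computed like so. The function name `seq_length_at` is hypothetical:

```python
def seq_length_at(progress, duration=0.3, min_seq_length=8,
                  max_seq_length=1024, step_size=8):
    """Sketch of the sequence length used at a training fraction in [0, 1]."""
    if progress >= duration:
        # Warmup finished: hold at the maximum for the rest of training.
        return max_seq_length
    # Interpolate from min to max, rounded down to a multiple of step_size.
    frac = progress / duration
    length = min_seq_length + frac * (max_seq_length - min_seq_length)
    length = int(length // step_size) * step_size
    return max(min_seq_length, min(length, max_seq_length))

seq_length_at(0.0)   # 8
seq_length_at(0.15)  # 512 (halfway through the warmup)
seq_length_at(0.3)   # 1024
```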

Tensors are either truncated (truncate=True) or reshaped to create new examples from the extra tokens (truncate=False).
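The difference between the two modes can be sketched in pure Python on lists of token ids (Composer operates on batch tensors; the function name `apply_seq_length` is illustrative, not part of the API):

```python
def apply_seq_length(batch, curr_seq_length, truncate=True):
    """Shorten a batch of token-id lists to curr_seq_length.

    truncate=True drops the extra tokens; truncate=False folds them
    into additional examples instead.
    """
    if truncate:
        return [seq[:curr_seq_length] for seq in batch]
    # Flatten the batch, then re-chunk into curr_seq_length-sized examples.
    flat = [tok for seq in batch for tok in seq]
    usable = (len(flat) // curr_seq_length) * curr_seq_length
    return [flat[i:i + curr_seq_length]
            for i in range(0, usable, curr_seq_length)]

batch = [[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]
apply_seq_length(batch, 3, truncate=True)
# [[1, 2, 3], [7, 8, 9]]
apply_seq_length(batch, 3, truncate=False)
# [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
```

Note that with truncate=False the batch size grows, since no tokens are discarded.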

This algorithm runs on Event.AFTER_DATALOADER to modify the sequence length of a batch of data after the model and data have been moved to accelerators.


step_size should be a multiple of eight for optimal throughput on NVIDIA GPUs.


Variable input lengths can create CUDA OOM errors. To avoid this, we follow the PyTorch notes and pre-allocate the memory with a blank forward and backward pass.

See the Method Card for more details.


from composer.algorithms import SeqLengthWarmup
from composer import Trainer

seq_length_warmup = SeqLengthWarmup(duration=0.5)

trainer = Trainer(model=model,
                  train_dataloader=train_dataloader,
                  max_duration="1ep",
                  algorithms=[seq_length_warmup])

Parameters:
  • duration (float, optional) – Fraction of total training spent warming up the sequence length. Default: 0.3.

  • min_seq_length (int, optional) – Minimum sequence length to start the warmup. Default: 8.

  • max_seq_length (int, optional) – Maximum sequence length to stop the warmup. Default: 1024.

  • step_size (int, optional) – Step size of the sequence length increase. Default: 8.

  • truncate (bool, optional) – Truncate sequences early, or reshape tensors to create new examples out of the extra tokens. Default: True.

  • preserve_end_of_sequence (bool, optional) – Preserve the end-of-sequence token of the batch when truncating. Useful when input formats include a unique end-of-sequence token. Ignored if truncate=False. Default: False. E.g., if batch["input_ids"] is [[10, 11, 12, 13, 14, 15]] and curr_seq_length=3, "input_ids" in the returned batch would be [[10, 11, 12]] with preserve_end_of_sequence=False and [[10, 11, 15]] with preserve_end_of_sequence=True. This behavior applies to any batch tensor with 2 or more dimensions.
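The preserve_end_of_sequence example above can be reproduced with a small pure-Python sketch (the helper name `truncate_ids` is illustrative; Composer applies the equivalent slicing to batch tensors):

```python
def truncate_ids(input_ids, curr_seq_length, preserve_end_of_sequence=False):
    """Truncate each sequence, optionally keeping its final token."""
    out = []
    for seq in input_ids:
        if preserve_end_of_sequence:
            # Keep the first curr_seq_length - 1 tokens plus the last token,
            # so a unique end-of-sequence token survives truncation.
            out.append(seq[:curr_seq_length - 1] + [seq[-1]])
        else:
            out.append(seq[:curr_seq_length])
    return out

truncate_ids([[10, 11, 12, 13, 14, 15]], 3)        # [[10, 11, 12]]
truncate_ids([[10, 11, 12, 13, 14, 15]], 3, True)  # [[10, 11, 15]]
```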