composer.functional.set_batch_sequence_length(batch, curr_seq_len, truncate=True, preserve_end_of_sequence=False)

Set the sequence length of a batch.

Changes the sequence length of all tensors in the provided dictionary to curr_seq_len by either truncating the tensors (truncate=True) or reshaping the tensors to create new examples from the extra tokens (truncate=False).
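The difference between the two modes can be sketched with plain Python lists standing in for tensors, so the snippet runs without any dependencies. This is an illustration of the behavior described above, not Composer's implementation; the helper names are hypothetical, and dropping leftover tokens that do not fill a whole row is an assumption.

```python
def truncate_to_seq_len(rows, curr_seq_len):
    """Keep only the first `curr_seq_len` tokens of every row."""
    return [row[:curr_seq_len] for row in rows]


def reshape_to_seq_len(rows, curr_seq_len):
    """Regroup all tokens into new rows of length `curr_seq_len`."""
    flat = [tok for row in rows for tok in row]
    # Assumption: tokens that do not fill a complete row are dropped.
    usable = len(flat) - (len(flat) % curr_seq_len)
    return [flat[i:i + curr_seq_len] for i in range(0, usable, curr_seq_len)]


batch = [[10, 11, 12, 13, 14, 15], [20, 21, 22, 23, 24, 25]]

truncated = truncate_to_seq_len(batch, 3)  # same number of rows, fewer tokens
reshaped = reshape_to_seq_len(batch, 3)    # more rows, no tokens discarded here

print(truncated)
print(reshaped)
```

Truncation keeps the batch size fixed and discards tokens, while reshaping keeps the tokens and increases the batch size.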


The schedule for curr_seq_len over training time should be managed outside of this function.
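As one possible way to manage that schedule, the sketch below computes a sequence length that grows linearly over a warmup period. The function name, its parameters, the linear shape, and the rounding to a multiple of step_size are all assumptions made for illustration; none of them are part of this function's API.

```python
def linear_seq_len(step, warmup_steps, min_seq_len, max_seq_len, step_size=8):
    """Hypothetical schedule: grow linearly from min_seq_len to max_seq_len."""
    # Fraction of the warmup completed, capped at 1.0 once warmup is over.
    frac = min(step / warmup_steps, 1.0)
    raw = min_seq_len + frac * (max_seq_len - min_seq_len)
    # Round down to a multiple of step_size, then clamp to the allowed range.
    rounded = int(raw) // step_size * step_size
    return max(min_seq_len, min(rounded, max_seq_len))


# Sequence lengths at a few points of a 100-step warmup from 8 to 128 tokens.
lengths = [linear_seq_len(s, 100, 8, 128) for s in (0, 50, 100, 250)]
print(lengths)
```

The value returned for the current step would then be passed as curr_seq_len on each call.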


Variable input lengths can cause CUDA out-of-memory (OOM) errors. To avoid this, we follow the PyTorch notes and pre-allocate the memory with a blank forward and backward pass.

Parameters:

  • batch (Dict[str, Tensor]) – The input batch to the model. Must be a dictionary.

  • curr_seq_len (int) – The desired sequence length to apply.

  • truncate (bool, optional) – If True, truncate sequences early. If False, reshape tensors to create new examples out of the extra tokens. Default: True.

  • preserve_end_of_sequence (bool, optional) – Preserve the end-of-sequence of the batch when truncating. Useful when input formats include a unique end-of-sequence token. Ignored if truncate=False. Default: False. E.g., if batch["input_ids"] is [[10, 11, 12, 13, 14, 15]] and curr_seq_len=3, "input_ids" in the returned batch would be [[10, 11, 12]] with preserve_end_of_sequence=False and would be [[10, 11, 15]] with preserve_end_of_sequence=True. This behavior applies to any batch tensor with 2 or more dimensions.


Returns:

Dict[str, Tensor] – a mapping of input tensors to the model, where every tensor has size curr_seq_len in its second dimension.
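The preserve_end_of_sequence example from the parameter description can be reproduced with a small dependency-free sketch. Plain lists stand in for tensors, the helper name is hypothetical, and the sketch assumes the end-of-sequence token is simply the last element of each row; the library itself may locate that token differently, for instance when rows are padded.

```python
def truncate_rows(rows, curr_seq_len, preserve_end_of_sequence=False):
    """Truncate each row to `curr_seq_len` (a positive int) tokens."""
    out = []
    for row in rows:
        if len(row) <= curr_seq_len:
            # Row is already short enough; keep it unchanged.
            out.append(list(row))
            continue
        kept = list(row[:curr_seq_len])
        if preserve_end_of_sequence:
            # Assumption: the final element of the row is its end-of-sequence
            # token, so it overwrites the last kept position.
            kept[-1] = row[-1]
        out.append(kept)
    return out


input_ids = [[10, 11, 12, 13, 14, 15]]

plain = truncate_rows(input_ids, 3)
with_eos = truncate_rows(input_ids, 3, preserve_end_of_sequence=True)

print(plain)
print(with_eos)
```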


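import composer.functional as cf

# num_epochs, train_loader, sequence_length, model and loss_fn are defined elsewhere.
for epoch in range(num_epochs):
    for X, y in train_loader:
        # X must be a Dict[str, Tensor]; every tensor is cut to sequence_length.
        X = cf.set_batch_sequence_length(X, sequence_length)
        y_hat = model(X)
        loss = loss_fn(y_hat, y)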