composer.functional.set_batch_sequence_length(batch, curr_seq_len, truncate=True, preserve_end_of_sequence=False)
Set the sequence length of a batch.
Changes the sequence length of all tensors in the provided dictionary to curr_seq_len, either by truncating the tensors (truncate=True) or by reshaping the tensors to create new examples from the extra tokens (truncate=False).
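As a rough illustration of the two modes, the following is a minimal sketch in plain PyTorch, not the library's implementation; it assumes every tensor in the batch has shape [batch_size, seq_len] and that seq_len is divisible by the target length:

    import torch

    batch = {"input_ids": torch.arange(16).reshape(2, 8)}  # 2 sequences of length 8
    curr_seq_len = 4

    # truncate=True: keep only the first curr_seq_len tokens of each sequence
    truncated = {k: v[:, :curr_seq_len] for k, v in batch.items()}  # shape [2, 4]

    # truncate=False: fold the leftover tokens into additional examples
    reshaped = {k: v.reshape(-1, curr_seq_len) for k, v in batch.items()}  # shape [4, 4]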
Note: The schedule for curr_seq_len over training time should be managed outside of this function.
Note: Variable input lengths can create CUDA OOM errors. To avoid this, we follow the PyTorch notes and pre-allocate the memory with a blank forward and backward pass.
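The pre-allocation idea can be sketched as follows; warm_up_memory is a hypothetical helper (not part of Composer's API) and assumes the model returns a single tensor:

    import torch

    def warm_up_memory(model, max_batch):
        # One throwaway forward/backward pass at the largest shapes the run
        # will see, so the CUDA caching allocator reserves enough memory up
        # front and later sequence-length changes do not trigger OOMs.
        out = model(max_batch)
        out.sum().backward()  # dummy scalar loss; gradients are discarded
        model.zero_grad(set_to_none=True)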
Parameters:
batch (Dict[str, Tensor]) – The input batch to the model; must be a dictionary.
curr_seq_len (int) – The desired sequence length to apply.
truncate (bool, optional) – Truncate sequences early, or reshape tensors to create new examples out of the extra tokens. Default: True.
preserve_end_of_sequence (bool, optional) – Preserve the end-of-sequence of the batch when truncating. Useful when input formats include a unique end-of-sequence token. Ignored if truncate=False. Default: False. E.g., if batch["input_ids"] is [[10, 11, 12, 13, 14, 15]] and curr_seq_len=3, "input_ids" in the returned batch would be [[10, 11, 12]] with preserve_end_of_sequence=False and would be [[10, 11, 15]] with preserve_end_of_sequence=True. This behavior applies to any batch tensor with 2 or more dimensions; see the sketch after this parameter list.
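The end-of-sequence example above can be reproduced in plain PyTorch (a minimal sketch, not the library's implementation):

    import torch

    ids = torch.tensor([[10, 11, 12, 13, 14, 15]])
    curr_seq_len = 3

    plain = ids[:, :curr_seq_len]         # preserve_end_of_sequence=False -> [[10, 11, 12]]

    kept = ids[:, :curr_seq_len].clone()  # preserve_end_of_sequence=True: overwrite the
    kept[:, -1] = ids[:, -1]              # last kept token with the final token -> [[10, 11, 15]]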
Returns:
Dict[str, Tensor] – A mapping of input tensors to the model, where all tensors have curr_seq_len in the second dimension.
Example:

    import composer.functional as cf

    for epoch in range(num_epochs):
        for X, y in train_loader:
            X = cf.set_batch_sequence_length(X, sequence_length)
            y_hat = model(X)
            loss = loss_fn(y_hat, y)
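Because the schedule for curr_seq_len must be managed outside this function, a caller might pair the loop above with a simple linear warmup. The helper below is a hypothetical sketch (not Composer's built-in SeqLengthWarmup algorithm); sequence_length_at, its parameters, and the rounding to a multiple of 8 are all illustrative choices:

    def sequence_length_at(step, max_steps, min_len=8, max_len=1024, multiple=8):
        # Linearly interpolate from min_len to max_len over max_steps,
        # rounding down to a multiple of `multiple` so shapes stay regular.
        frac = min(step / max_steps, 1.0)
        length = int(min_len + frac * (max_len - min_len))
        return max(min_len, (length // multiple) * multiple)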