SeqLengthWarmup
- class composer.algorithms.SeqLengthWarmup(duration=0.3, min_seq_length=8, max_seq_length=1024, step_size=8, truncate=True, preserve_end_of_sequence=False)
- Progressively increases the sequence length during training.

  Changes the sequence length of all tensors in the input batch. The sequence length increases from min_seq_length to max_seq_length in steps of step_size during the first duration fraction of training.

  The sequence length is then kept at max_seq_length for the rest of training.

  Tensors are either truncated (truncate=True) or reshaped to create new examples from the extra tokens (truncate=False).

  This algorithm runs on Event.AFTER_DATALOADER to modify the sequence length of a batch of data after the model and data have been moved to accelerators.

  Note: step_size should be a multiple of eight for optimal throughput on NVIDIA GPUs.

  Note: Variable input lengths can create CUDA OOM errors. To avoid this, we follow the PyTorch notes and pre-allocate the memory with a blank forward and backward pass.

  See the Method Card for more details.

  Example:

      from composer.algorithms import SeqLengthWarmup
      from composer import Trainer

      seq_length_warmup = SeqLengthWarmup(duration=0.5,
                                          min_seq_length=8,
                                          max_seq_length=1024,
                                          step_size=8,
                                          truncate=True,
                                          preserve_end_of_sequence=False)

      trainer = Trainer(model=model,
                        train_dataloader=train_dataloader,
                        max_duration="1ep",
                        algorithms=[seq_length_warmup])

- Parameters
- duration (float, optional) – Fraction of total training over which the sequence length is warmed up. Default: 0.3.
- min_seq_length (int, optional) – Minimum sequence length to start the warmup. Default: 8.
- max_seq_length (int, optional) – Maximum sequence length to stop the warmup. Default: 1024.
- step_size (int, optional) – Step size of sequence length. Default: 8.
- truncate (bool, optional) – Truncate sequences early, or reshape tensors to create new examples out of the extra tokens. Default: True.
- preserve_end_of_sequence (bool, optional) – Preserve the end-of-sequence of the batch when truncating. Useful when input formats include a unique end-of-sequence token. Ignored if truncate=False. Default: False. E.g., if batch["input_ids"] is [[10, 11, 12, 13, 14, 15]] and curr_seq_length=3, "input_ids" in the returned batch would be [[10, 11, 12]] with preserve_end_of_sequence=False and would be [[10, 11, 15]] with preserve_end_of_sequence=True. This behavior applies to any batch tensor with 2 or more dimensions.
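
The parameters above can be illustrated with a small, self-contained sketch. It is not Composer's implementation: the function names (warmup_seq_length, truncate_rows, reshape_rows) are hypothetical, plain Python lists stand in for batch tensors, the schedule is assumed to be a simple step-wise linear ramp, rows are assumed to be unpadded, and leftover tokens that do not fill a whole new example are assumed to be dropped when reshaping.

```python
def warmup_seq_length(step, total_steps, duration=0.3, min_seq_length=8,
                      max_seq_length=1024, step_size=8):
    """Hypothetical schedule: sequence length to use at a given training step."""
    warmup_steps = int(total_steps * duration)
    if warmup_steps <= 0 or step >= warmup_steps:
        # After the warmup window the length stays at the maximum.
        return max_seq_length
    # Number of step_size increments between the minimum and maximum lengths.
    num_increments = (max_seq_length - min_seq_length) // step_size
    if num_increments == 0:
        return max_seq_length
    # Training steps spent at each length before moving to the next one.
    steps_per_increment = max(1, warmup_steps // num_increments)
    length = min_seq_length + step_size * (step // steps_per_increment)
    # Clamp in case the integer division overshoots the maximum.
    return max(min_seq_length, min(length, max_seq_length))


def truncate_rows(rows, curr_seq_length, preserve_end_of_sequence=False):
    """truncate=True analogue: cut every row to curr_seq_length tokens."""
    truncated = []
    for row in rows:
        new_row = list(row[:curr_seq_length])
        if preserve_end_of_sequence and new_row and len(row) > curr_seq_length:
            # Overwrite the last kept position with the row's final token.
            # Assumption: rows are unpadded, so the final element is the
            # end-of-sequence token.
            new_row[-1] = row[-1]
        truncated.append(new_row)
    return truncated


def reshape_rows(rows, curr_seq_length):
    """truncate=False analogue: split rows into new examples of curr_seq_length."""
    reshaped = []
    for row in rows:
        # Assumption: tokens that do not fill a whole new example are dropped.
        usable = (len(row) // curr_seq_length) * curr_seq_length
        for start in range(0, usable, curr_seq_length):
            reshaped.append(list(row[start:start + curr_seq_length]))
    return reshaped


batch_input_ids = [[10, 11, 12, 13, 14, 15]]
print(truncate_rows(batch_input_ids, 3))                                 # [[10, 11, 12]]
print(truncate_rows(batch_input_ids, 3, preserve_end_of_sequence=True))  # [[10, 11, 15]]
print(reshape_rows(batch_input_ids, 3))                                  # [[10, 11, 12], [13, 14, 15]]
print(warmup_seq_length(step=0, total_steps=1000))                       # 8
print(warmup_seq_length(step=300, total_steps=1000))                     # 1024
```

The two truncate_rows calls reproduce the preserve_end_of_sequence example given in the parameter description. With truncate=False the batch grows in the example dimension instead of discarding tokens, which is why the reshaped result contains two rows.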