Fast Resumption#
Being resistant to timeouts, hardware failures, or other errors is crucial to efficient distributed training. While other dataset implementations require iterating through all previously seen samples before resuming, Streaming is stateful, allowing immediate and deterministic resumption in the middle of an epoch.
Saving and loading state#
To get fast, deterministic mid-epoch resumption, make sure to use the streaming.StreamingDataLoader object. StreamingDataLoader works in conjunction with StreamingDataset to save and load dataset state, and otherwise works exactly like a normal PyTorch DataLoader.
When checkpointing, simply call the state_dict method of StreamingDataLoader and save the result along with your checkpoint. Then, when resuming, call load_state_dict with the saved state, and you’ll be running in no time. Here’s an example:
from streaming import StreamingDataset, StreamingDataLoader

dataset = StreamingDataset(local='/tmp/cache', remote='s3://remote/dataset', batch_size=1)
dataloader = StreamingDataLoader(dataset, batch_size=1)

# Here, we assume each sample in our dataset has fields 'x' and 'y'.
# We save the dataloader state after batch index 4 (5 batches seen) and stop
# after batch index 6 (7 batches seen).
state_dict = None
for i, batch in enumerate(dataloader):
    print(i, batch['x'], batch['y'])
    if i == 4:
        state_dict = dataloader.state_dict()
    if i == 6:
        break
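In practice, the dataloader state would be written to disk alongside the model weights rather than kept in a Python variable. Here is a minimal sketch of that pattern, assuming a PyTorch model named model and the checkpoint path /tmp/checkpoint.pt, neither of which appears in the example above:

import torch

# Bundle the model weights and the dataloader state into one checkpoint file.
# 'model' is a hypothetical torch.nn.Module from your training setup.
torch.save({
    'model': model.state_dict(),
    'dataloader': dataloader.state_dict(),
}, '/tmp/checkpoint.pt')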
At this point, we’ve saved the dataloader state after 5 batches, but iterated through 7 before training “stopped”. This is akin to a training job failing some time after a checkpointing interval. Now, we resume from where we left off:
# Create a new dataset and dataloader pointing at the same data
dataset_2 = StreamingDataset(local='/tmp/cache', remote='s3://remote/dataset', batch_size=1)
dataloader_2 = StreamingDataLoader(dataset_2, batch_size=1)

# Load in the state dict that was previously saved
dataloader_2.load_state_dict(state_dict)

# Iterate over the dataset, which now resumes from batch index 5, the first
# batch after the saved state.
for i, batch in enumerate(dataloader_2):
    print(i, batch['x'], batch['y'])
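If the state was persisted to disk as in the earlier sketch, the same hypothetical checkpoint file can be read back before resuming:

import torch

# Restore both the model weights and the dataloader state from the
# hypothetical checkpoint file written above.
checkpoint = torch.load('/tmp/checkpoint.pt')
model.load_state_dict(checkpoint['model'])
dataloader_2.load_state_dict(checkpoint['dataloader'])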
Resumption with Composer#
When training with Composer, our open-source deep learning training library built on top of PyTorch, fast resumption is handled automatically. Composer and Streaming work seamlessly together to provide efficient, scalable neural network training.
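As an illustration only, a Composer training run that checkpoints periodically and resumes automatically from the latest checkpoint might look like the sketch below; model is assumed to be a ComposerModel, and the save_folder, save_interval, and run_name values are placeholders rather than prescribed settings:

from composer import Trainer

# Composer checkpoints include the StreamingDataLoader state, so a restarted
# run picks up mid-epoch instead of replaying previously seen samples.
trainer = Trainer(
    model=model,                  # assumed ComposerModel
    train_dataloader=dataloader,  # the StreamingDataLoader from above
    max_duration='1ep',
    save_folder='checkpoints',                # placeholder checkpoint directory
    save_interval='100ba',                    # placeholder checkpoint interval
    run_name='streaming-resumption-example',  # placeholder run name
    autoresume=True,              # resume from the latest checkpoint if one exists
)
trainer.fit()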