Text Data: Synthetic NLP#
In this tutorial, we will demonstrate how to create a synthetic dataset, write it into the streaming format, and use the StreamingDataset class to load it.
Recommended Background#
This tutorial assumes that you’re reasonably familiar with the workings of datasets and dataloaders for training deep learning models.
If you’re already familiar with Streaming’s dataset classes (StreamingDataset and MDSWriter), that’s great. If not, you may want to pause while working through the tutorial and look at the docs referenced along the way.
Tutorial Goals and Concepts Covered#
The goal of this tutorial is to showcase how to prepare the dataset and use Streaming data loading to iterate over and fetch samples. It consists of a few steps:
Generate a synthetic dataset
Prepare the dataset for streaming
Stream the dataset to the local machine
Iterate through the dataset and fetch samples
Let’s get started!
Setup#
Let’s start by making sure the right packages are installed and imported. We need to install the mosaicml-streaming package, which installs the dependencies needed to run this tutorial.
[ ]:
%pip install mosaicml-streaming
# To install from source instead of the last release, comment the command above and uncomment the following one.
# %pip install git+https://github.com/mosaicml/streaming.git
# (Optional) To upload a streaming dataset to an AWS S3 bucket
%pip install awscli
[ ]:
import os
import shutil
from typing import Any, Dict, List, Tuple
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
We’ll be using Streaming’s MDSWriter, which writes the dataset in the Streaming format, and StreamingDataset, which loads it.
[ ]:
from streaming import MDSWriter, StreamingDataset
Global settings#
For this tutorial, let’s define some global settings up front.
[ ]:
# the location of the "remote" streaming dataset (`sds`).
# Upload `out_root` to your cloud storage provider of choice. If `out_root` is a cloud provider
# path, shard files are automatically uploaded.
out_root = "./sds"
out_train = "./sds/train"
out_val = "./sds/val"
# the location to download the streaming dataset during training
local = './local'
local_train = './local/train'
local_val = './local/val'
# toggle shuffling in dataloader
shuffle_train = True
shuffle_val = False
# training batch size
batch_size = 512
[ ]:
# upload location for the dataset splits (change this if you want to upload to a different location, for example, AWS S3 bucket location)
upload_location = None
if upload_location is None:
    upload_train_location = None
    upload_val_location = None
else:
    upload_train_location = os.path.join(upload_location, 'train')
    upload_val_location = os.path.join(upload_location, 'val')
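If you’d like to stream from cloud storage rather than the local filesystem, point upload_location at a bucket you can write to. Below is a minimal sketch; the S3 path is a hypothetical placeholder, not part of the original tutorial.
[ ]:
# Hypothetical example: uncomment and edit to upload the splits to your own bucket.
# The bucket path below is a placeholder; replace it with a location you can write to.
# upload_location = 's3://my-bucket/streaming-synthetic-nlp'
# upload_train_location = os.path.join(upload_location, 'train')
# upload_val_location = os.path.join(upload_location, 'val')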
Create a Synthetic NLP dataset#
In this tutorial, we will be creating a synthetic number-saying dataset, i.e. converting numbers from digits to words; for example, the number 123 would be spelled as one hundred twenty three. The numbers are generated sequentially with a random positive/negative sign.
Let’s define some utility functions to generate this synthetic number-saying dataset.
[ ]:
# Word representation of a number
ones = ('zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen ' +
        'fifteen sixteen seventeen eighteen nineteen').split()

tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split()


def say(i: int) -> List[str]:
    """Get the word form of a number.

    Args:
        i (int): The number.

    Returns:
        List[str]: The number in word form.
    """
    if i < 0:
        return ['negative'] + say(-i)
    elif i <= 19:
        return [ones[i]]
    elif i < 100:
        return [tens[i // 10 - 2]] + ([ones[i % 10]] if i % 10 else [])
    elif i < 1_000:
        return [ones[i // 100], 'hundred'] + (say(i % 100) if i % 100 else [])
    elif i < 1_000_000:
        return say(i // 1_000) + ['thousand'] + (say(i % 1_000) if i % 1_000 else [])
    elif i < 1_000_000_000:
        return say(i // 1_000_000) + ['million'] + (say(i % 1_000_000) if i % 1_000_000 else [])
    else:
        assert False
def get_numbers(num_train: int, num_val: int) -> Tuple[List[int], List[int]]:
    """Get two non-overlapping splits of sequential numbers with random signs.

    The train samples are built from indices [0, num_train) and the val samples
    from indices [num_train, num_train + num_val).

    Args:
        num_train (int): Number of training samples.
        num_val (int): Number of validation samples.

    Returns:
        Tuple[List[int], List[int]]: The two generated splits.
    """
    total = num_train + num_val
    numbers = []
    bar = tqdm(total=total, leave=False)
    i = 0
    while i < total:
        was = len(numbers)
        sign = (np.random.random() < 0.8) * 2 - 1
        numbers.append(sign * i)
        bar.update(len(numbers) - was)
        i += 1
    return numbers[:num_train], numbers[num_train:]
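As a quick sanity check, let’s see what say returns for a couple of inputs (this cell is illustrative and not part of the original tutorial):
[ ]:
# Spot-check the word form of a positive and a negative number.
print(' '.join(say(123)))     # one hundred twenty three
print(' '.join(say(-6_325)))  # negative six thousand three hundred twenty five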
Define methods to generate the train and validation samples, where each sample is a dictionary with the attributes {'number': <integer number>, 'words': <word representation of the integer as a string>}.
[ ]:
def generate_samples(numbers: List[int]) -> List[Dict[str, Any]]:
    """Generate samples from a list of numbers.

    Args:
        numbers (List[int]): The numbers.

    Returns:
        List[Dict[str, Any]]: The corresponding samples.
    """
    samples = []
    for num in numbers:
        words = ' '.join(say(num))
        sample = {'number': num, 'words': words}
        samples.append(sample)
    return samples


def get_dataset(num_train: int, num_val: int) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Generate a number-saying dataset of the given size.

    Args:
        num_train (int): Number of training samples.
        num_val (int): Number of validation samples.

    Returns:
        Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: The two generated splits.
    """
    train_nums, val_nums = get_numbers(num_train, num_val)
    train_samples = generate_samples(train_nums)
    val_samples = generate_samples(val_nums)
    return train_samples, val_samples
Create non-overlapping train and val splits of unique numbers.
[ ]:
# Number of training and validation samples
num_train_samples = 10_000 # 10k samples
num_val_samples = 2000 # 2k samples
# Create the samples.
print(f'Generating synthetic dataset ({num_train_samples} train, {num_val_samples} val)...')
train_samples, val_samples = get_dataset(num_train_samples, num_val_samples)
splits = [
    ('train', train_samples),
    ('val', val_samples),
]
Let’s inspect the first train and val sample.
[ ]:
print(f'Train sample: {train_samples[0]}')
print(f'Val sample: {val_samples[0]}')
Convert the dataset to MosaicML Streaming#
We are going to use the MDSWriter to convert the raw synthetic NLP dataset into the MDS (.mds) file format.
For more information on Streaming’s MDSWriter class, check out the API reference.
[ ]:
# Mapping of sample keyword with their data type
columns = {
    'number': 'int',
    'words': 'str',
}
# Compression algorithm to use for dataset
compression = 'zstd:12'
# Hashing algorithm to use for dataset
hashes = ['sha1', 'xxh3_64']
# shard size limit, in bytes
size_limit = 1 << 16 # Override to a small number for more shards.
print(f'Saving dataset (to {out_root})...')
for split, samples in splits:
    print(f'* {split}')
    dirname = os.path.join(out_root, split)
    with MDSWriter(out=dirname, columns=columns, compression=compression,
                   hashes=hashes, size_limit=size_limit) as out:
        for sample in tqdm(samples, leave=False):
            out.write(sample)
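As a quick check, we can list what was written for each split. Each split directory should contain an index.json plus one or more compressed shard files (the exact shard filenames may vary across Streaming versions):
[ ]:
# List the files MDSWriter produced for each split.
for split, _ in splits:
    dirname = os.path.join(out_root, split)
    print(f'{split}: {sorted(os.listdir(dirname))}')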
Now that we’ve written the datasets to out_root, we can upload them to a cloud storage provider, and we are ready to stream them.
[ ]:
remote_train = upload_train_location or out_train # replace this with your URL for cloud streaming
remote_val = upload_val_location or out_val
Loading the Data#
We use Streaming’s StreamingDataset class to load and deserialize the data. Let’s verify that iterating over the StreamingDataset gives us back the exact raw samples in the same deterministic sample order.
Note that StreamingDataset requires a batch_size argument for iteration. This batch_size is per-device and should be the same as the DataLoader batch size.
For more information on the StreamingDataset class, check out the API reference.
[ ]:
# Load the samples back.
print('Walking the dataset:')
print('verifying samples for train split')
train_dataset = StreamingDataset(remote=upload_location or out_root, local=local, batch_size=batch_size, split='train', shuffle=False)
for old, new in tqdm(zip(train_samples, train_dataset), total=len(train_samples), leave=False):
    assert old == new

print('verifying samples for val split')
val_dataset = StreamingDataset(remote=upload_location or out_root, local=local, batch_size=batch_size, split='val', shuffle=False)
for old, new in tqdm(zip(val_samples, val_dataset), total=len(val_samples), leave=False):
    assert old == new
We can also visualize samples by doing Pythonic or NumPy-style indexing on a StreamingDataset.
[ ]:
# Fetch the sample at index 10 and print it to the console
print(f'Sample 10: {train_dataset[10]}')

# Fetch multiple samples
indices = [-1, 30, [12, -14], slice(-1, -10, -2), np.array([10, -20])]
for indx in indices:
    print(f'Sample {indx}: {train_dataset[indx]}')
Below are some dataset properties that are useful for debugging and model training. For more information on the StreamingDataset parameters, check out the API reference.
[ ]:
# Get the total number of samples
print(f'Total number of samples: {train_dataset.num_samples}')
# Get the number of shard files
print(f'Total number of shards: {len(train_dataset.shards)}')
# Get the number of samples inside each shard file.
# The number of samples in each shard can vary based on the size of each sample.
print(f'Number of samples inside each shard: {train_dataset.samples_per_shard}')
We can now wrap our streaming datasets in standard PyTorch DataLoaders for training!
[ ]:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
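As a quick sanity check, here is a minimal sketch of pulling a single batch from the train dataloader. With PyTorch’s default collation, the integer 'number' column is batched into a tensor and the string 'words' column into a list of strings:
[ ]:
# Fetch one batch and inspect its structure.
batch = next(iter(train_dataloader))
print(batch['number'].shape)  # tensor of shape [batch_size]
print(batch['words'][:2])     # list of strings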
Cleanup#
That’s it. No need to hang on to the files created by the tutorial…
[ ]:
shutil.rmtree(out_root, ignore_errors=True)
shutil.rmtree(local, ignore_errors=True)
What next?#
You’ve now seen an in-depth look at how to prepare and use streaming datasets with PyTorch.
To continue learning about Streaming, please continue to explore our examples!
Come get involved with MosaicML!#
We’d love for you to get involved with the MosaicML community in any of these ways:
Star Streaming on GitHub#
Help make others aware of our work by starring Streaming on GitHub.
Join the MosaicML Slack#
Head on over to the MosaicML Slack to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!
Contribute to Streaming#
Is there a bug you noticed or a feature you’d like? File an issue or make a pull request!