A streaming dataset which fetches samples from anywhere and yields them on the fly, with caching/eviction and random access features:
- Very flexible in terms of data sources:
  - Can be constructed from a CSV/parquet/json index, a file directory, a CSV/parquet/json wildcard, or a good old "split_file_texts.txt"
  - Supports remote and local data sources (e.g., S3, GCS, HTTP, local files, etc.)
- Has caching and eviction logic, so that you can efficiently work with large datasets without hitting disk space limits.
- Has standard data processing transforms for images, videos, audios, text, and metadata.
- Supports random access (through blocking calls)!
- Has a built-in multi-stream dataloader (i.e., streaming from multiple data sources in parallel) with various mixing strategies between the streams.
The package is available on the corporate PyPI:
pip install streaming-dataset

There is also a docker image for SnapVideo-V3 with streaming-dataset pre-installed:
- streaming-dataset==0.4.3: 440036398022.dkr.ecr.us-west-2.amazonaws.com/facecraft-ml:genai-video-aws-fa3-h100-torch271-126-sds-0-4-3
- streaming-dataset==0.4.6: 440036398022.dkr.ecr.us-west-2.amazonaws.com/facecraft-ml:genai-video-aws-fa3-h100-torch271-126-sds-0-4-6
Here is an example of how to use the streaming dataset for 2 streams. Let's assume that the first stream is a remote S3 folder of videos, and the second stream is given via an index.parquet file.
For the first stream, we would be loading it directly from S3, but for the second stream, we'll go through the process of preparing the data from BigQuery.
First, create a simple config for your output BQ table, i.e. composeme-v2.yaml:
bq_project: "research-prototypes"
sql_query: "
  SELECT *
  FROM `research-prototypes.generative_ai_data_platform_test.personalization_getty_dataset`
  WHERE aes_score > 4
  AND caption IS NOT NULL
  AND getty_caption IS NOT NULL
  AND getty_title IS NOT NULL
  AND Body2DPoseHandsFace IS NOT NULL
  AND InstanceSegmentation IS NOT NULL
  AND FashionSegmentationImaterialist IS NOT NULL
  ORDER BY RAND()
  "
s3_destination_path: s3://snap-genvid/datasets/sds-index-files/composeme-v2.parquet
s3_bucket_region: us-west-2 # The region of the S3 bucket. Can be left empty, but would lead to an error in case of a mismatch between $AWS_REGION in the env and the actual region of the bucket.
recompute: true
val_ratio: 0.1 # The fraction of the dataset to use for the validation split.
max_num_val_rows: 10000 # The maximum number of rows in the validation dataset.
local_tmp_dir: ~ # Local temporary directory where the merged parquet will be saved to if provided (needed for large 20M+ rows outputs). You can likely use `/lssd/index-exports-tmp`.
gcs_tmp_dir: ~ # Where to save intermediate results (needed for huge 70M+ rows outputs). You can likely use `gs://dlahiri/index-exports-tmp`
row_group_size: 20000 # Number of rows per parquet row group.

Note: make sure that s3_destination_path is in the correct AWS region for your future training job.
Otherwise, there might be problems when fetching parquet chunks from S3.
Also, it's important to specify local_tmp_dir/gcs_tmp_dir for large outputs (20M+ rows or 100GB+), so that the intermediate results get pushed through them.
Otherwise, the job might fail either because the S3 multi-part upload would attempt to use too many parts (over the 10K limit) or because the BQ auth token would expire.
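If you are unsure which region the destination bucket lives in, you can check it with boto3 (already in the dependency list below). A minimal sketch, using the bucket from the config above:

```python
import boto3

# Ask S3 for the region of the destination bucket, so that s3_bucket_region /
# $AWS_REGION match the actual bucket location. For us-east-1 buckets, S3 returns None.
location = boto3.client('s3').get_bucket_location(Bucket='snap-genvid')['LocationConstraint']
print(location or 'us-east-1')  # Expected to print 'us-west-2' for this config.
```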
Then, install the script env and run the BQ export script:
pip install genml-training-tools # Install this within the first hour of job start
pip install --upgrade google-cloud google-cloud-bigquery google-cloud-storage db-dtypes pandas pyarrow s3fs loguru pydantic PyYAML boto3 google-cloud-bigquery-storage
python scripts/construct_index_from_bq.py composeme-v2.yaml

It will create a single parquet index file and upload it to S3.
It will also create a validation index file with val_ratio fraction of the rows (up to max_num_val_rows).
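Once the export finishes, you can sanity-check the resulting index file, e.g. that the row count and row group size look right. A minimal sketch using pyarrow and s3fs (both in the dependency list above):

```python
import pyarrow.parquet as pq
import s3fs

# Open the exported index directly from S3 and inspect its parquet metadata.
fs = s3fs.S3FileSystem()
with fs.open('s3://snap-genvid/datasets/sds-index-files/composeme-v2.parquet', 'rb') as f:
    meta = pq.ParquetFile(f).metadata
    print('num rows:', meta.num_rows)
    print('num row groups:', meta.num_row_groups)
    print('rows in the first row group:', meta.row_group(0).num_rows)  # Should be <= row_group_size (20000).
```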
There are multiple examples for different modalities in the examples/ folder.
For brevity, we won't repeat them here, and just show some basic usage with a single stream.
import os
from typing import Callable

os.environ['AWS_REGION'] = 'us-west-2' # Make sure that AWS_REGION is set to the region of your S3 bucket.
os.environ['SDS_LOG_LEVEL'] = 'DEBUG' # Set the log level to DEBUG to see more logs.

import torch
from PIL import Image
from sds.dataset import StreamingDataset
from sds.transforms import presets


def build_transforms() -> list[Callable]:
    # SDS contains some standard transforms for common use cases. Let's use the one for images.
    image_transforms = presets.create_standard_image_pipeline(
        image_field='personalization_image_path', # Where the downloading field is located.
        resolution=(256, 256), # Resize the image to this resolution.
    )
    # Let's only keep the 'image' and 'media_id' fields in the final output, since there is a lot of unused stuff in the index file.
    # (This unused stuff can break pytorch's default collate function.)
    sample_cleaning_transform = presets.EnsureFieldsTransform(fields_whitelist=['image', 'media_id'], drop_others=True)
    return image_transforms + [HorizontalFlipTransform(), sample_cleaning_transform]


class HorizontalFlipTransform:
    # A simple horizontal flip transform just for illustration purposes.
    # A transform is just a class with a __call__ method which takes a sample dict and returns a sample dict.
    def __call__(self, sample: dict) -> dict:
        assert 'image' in sample, f'Image field is missing in the sample: {sample.keys()}'
        assert isinstance(sample['image'], torch.Tensor), f'Image is not a torch.Tensor, but {type(sample["image"])}'
        assert sample['image'].ndim == 3, f'Wrong image shape: {sample["image"].shape}'
        assert sample['image'].shape[0] in [1, 3], f'Wrong image shape: {sample["image"].shape}'
        sample['image'] = torch.flip(sample['image'], dims=[2]) if torch.rand(1) < 0.5 else sample['image'] # [c, h, w]
        return sample


def main():
    dataset = StreamingDataset(
        src='s3://snap-genvid/datasets/sds-index-files/composeme-v2.parquet',
        dst='/tmp/where/to/download', # Where to download the samples.
        transforms=build_transforms(), # A list of transforms to apply.
        # Which columns to download from the index. Their values should be URLs/paths.
        # Once downloaded, a sample dict will be constructed, with the column names pointing to the local paths.
        # After that, all the transforms will be applied to the sample dict.
        columns_to_download=['personalization_image_path', 'personalization_caption_path'],
        index_col_name='media_id', # Name of the column to use as the index (should have unique values for samples).
        num_downloading_workers=5, # How many parallel downloading workers to use.
        prefetch=100, # How many samples to pre-download ahead of time.
        cache_limit='10gb', # How much disk space to use for caching.
        allow_missing_columns=False, # Should we ignore or not ignore samples with some missing `columns_to_download`?
        # Some configuration for the `lazy index`. Sometimes, the index file is huge, so it's better
        # not to load it all at once, but rather fetch it in chunks.
        lazy_index_chunk_size=1000, # Chunk size to fetch.
        lazy_index_num_threads=2, # In how many threads to fetch the index chunks.
        lazy_index_prefetch_factor=3, # How many chunks to prefetch ahead of time.
        min_num_pending_tasks_thresh=200, # How many pending download tasks there should be before we start scheduling the next index chunk for downloading.
        # Shuffle seed to use. If None, no shuffling is performed.
        # If -1, then a random seed will be created on the first rank and distributed across all the ranks.
        shuffle_seed=123,
        # Let's filter out the images without any related images.
        # The SQL query is quite arbitrary and applied on each index chunk before scheduling downloading tasks.
        sql_query="SELECT * FROM dataframe WHERE num_related_images > 0",
        # You can limit the dataset size artificially for debugging purposes (i.e., trying to overfit).
        # Note: to reduce the dataset size below the batch_gpu size, you need to also set unaligned_worker_index=True and infinite_iteration=True.
        max_size=50_000,
        # unaligned_worker_index: bool = False, # Shall each worker iterate over the global dataset independently? Bad design, but helpful for tiny datasets.
        # infinite_iteration: bool = False, # If True, the dataset would be iterated infinitely. Useful when you for some reason have batch_size > dataset_size and drop_last=True.
        # Enable these flags for debugging purposes. Otherwise, the exceptions will be silenced.
        print_exceptions=True,
        print_traceback=True,
    )

    # Note: we have (slow) random access!
    sample = dataset[12345]
    Image.fromarray(sample['image'].permute(1, 2, 0).cpu().numpy().astype('uint8')).save('/tmp/where/to/download/debug-sample-12345.png')

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=2, pin_memory=True, drop_last=True)
    data_iterator = iter(dataloader)

    for epoch in range(1, 2):
        for i in range(7):
            batch = next(data_iterator)
            print(f'Processing batch #{i + 1} with {len(batch["media_id"])} items.')
            for j, img in enumerate(batch['image']):
                assert img.dtype == torch.uint8, f'Image dtype is not uint8, but {img.dtype}'
                img = img.permute(1, 2, 0).cpu().numpy().astype('uint8') # [h, w, c]
                Image.fromarray(img).save(os.path.join('/tmp/where/to/download', batch['media_id'][j] + '.png'))


if __name__ == '__main__':
    main()

For this example, we are just iterating over the text files in a local folder.
The nasty part of the folder datasets is the necessity to specify a data_type.
import os

import torch
from sds.dataset import StreamingDataset

src = '/tmp/dummy-data' # Could, in fact, be an S3 path as well (though harder to generate the data for)
dst = '/tmp/dummy-out'
EXT = 'txt'

# Let's generate some dummy data first.
os.makedirs(src, exist_ok=True)
for i in range(10):
    with open(os.path.join(src, f'{i:05d}.{EXT}'), 'w') as f:
        f.write(f'This is sample {i}.\n' * 10)


class LoadTransform:
    def __call__(self, sample: dict) -> dict:
        # We got a sample with keys: `index` (filename) and `txt` (path to the text file, inferred from the extension).
        with open(sample[EXT], 'r') as f:
            sample[EXT] = f.read() # Load the text content.
        return sample


# For folder datasets, `columns_to_download` specifies the sample data to copy (from local or S3).
dataset = StreamingDataset(src, dst, columns_to_download=[EXT], data_type='text', transforms=[LoadTransform()])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=0)

for i, batch in enumerate(dataloader):
    print(f'Batch {i}', batch)

If something goes wrong, here are some debugging tips (a minimal setup combining them is sketched after this list):
- Set `num_workers=0` in the dataloader.
- Set `print_exceptions=True` and `print_traceback=True` in the dataset to see what is going wrong.
- Set `SDS_LOG_LEVEL=DEBUG` in the env to see detailed internal logs.
- Make sure that you are using the latest version of the package: new versions can fix old bugs.
- Make sure that you are not using the latest version of the package: new versions can introduce new bugs.
- Make sure you are on the right version: new versions can break backward compatibility.
- If the dataloader is slow, try varying 1) `prefetch_factor` and `num_workers` in the dataloader; 2) `prefetch` and `num_downloading_workers` in the dataset; and 3) the lazy indexing parameters.
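A minimal debugging setup combining these knobs might look like the sketch below. It only reuses arguments already shown in the first example above (nothing here is new API):

```python
import os

os.environ['AWS_REGION'] = 'us-west-2'
os.environ['SDS_LOG_LEVEL'] = 'DEBUG'  # Verbose internal logs.

import torch
from sds.dataset import StreamingDataset
from sds.transforms import presets

dataset = StreamingDataset(
    src='s3://snap-genvid/datasets/sds-index-files/composeme-v2.parquet',
    dst='/tmp/where/to/download',
    columns_to_download=['personalization_image_path'],
    index_col_name='media_id',
    # Keep only a couple of fields so that the default collate function does not choke on unused index columns.
    transforms=[presets.EnsureFieldsTransform(fields_whitelist=['personalization_image_path', 'media_id'], drop_others=True)],
    print_exceptions=True,  # Surface exceptions instead of silencing them...
    print_traceback=True,   # ...and print the full tracebacks.
)
# num_workers=0 keeps everything in the main process, so tracebacks and breakpoints behave as expected.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0)
batch = next(iter(dataloader))
print(batch['media_id'], batch['personalization_image_path'])
```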
The entry point is the StreamingDataset class, which takes a source `src` plus configuration arguments and does the following:
- It constructs an index from the source:
  - if `src` is a local or remote CSV/parquet/json file, it reads the index from there.
  - if `src` is a local or remote directory, it scans the directory and constructs an index from the files.
  - if `src` is a local or remote index wildcard (e.g., `/path/to/*.csv`), it scans the files matching the wildcard and constructs an index from them.
- Then, if it's a lazy index (controlled via `lazy_index_chunk_size`), we broadcast the index metadata to each rank. If not lazy, we save the index on the node as a parquet file (for memory-efficient chunked reading) and return an index metadata object.
- After that, the dataset init initializes an "empty" downloader (without initializing the downloader workers). Without workers, the downloader can be used for random access queries, such as `dataset[0]`, which will download the sample with blocking.
- When the iterator is created, we initialize the downloader workers, which will download samples in parallel and cache them on disk.
- Then, each dataloading worker reads its index chunk, shuffles it and starts the generator, which yields samples one by one. If it's a lazy index, then each worker would be reading index chunks with some prefetching (via downloading threads).
- The downloader yields the indices of the downloaded samples. Then, we look up the sample metadata by its index and process it through a sequence of sample processing callbacks (named transforms; a minimal one is sketched after this list).
- Sample transforms are very flexible and you can incorporate any processing logic you want, such as decoding images/videos/audios, applying augmentations, etc.
- There are "presets" of sample transforms which should cover 80% of the cases for image/video/text-to-video/etc. use cases.
- Caching and eviction logic is performed by the StreamingDataset class, which keeps track of downloaded file sizes and evicts the oldest ones when the cache size exceeds the threshold. Currently, the cache size is set naively per worker as `node_cache_size / num_workers`, assuming that each worker has equal load (e.g., with a 10 GB node cache and 4 workers, each worker gets a 2.5 GB budget).
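To make the transform flow concrete, here is a minimal sketch of a post-download transform, assuming the `personalization_caption_path` column from the example above: after the download step, the sample dict maps that column name to a local file path, which the transform replaces with the loaded text.

```python
class LoadCaptionTransform:
    """Reads the downloaded caption file and puts its text into the sample dict."""
    def __call__(self, sample: dict) -> dict:
        # After downloading, sample['personalization_caption_path'] is a *local* path.
        with open(sample['personalization_caption_path'], 'r') as f:
            sample['caption'] = f.read().strip()
        return sample
```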
- Index construction
- Dataset iterator
- Shuffling each epoch
- Lazy index so that we can efficiently initialize large datasets on a single node without hitting disk space limits
- Cache data + evict cold samples
- Video decoding
- Audio loading
- Resumal logic. Only if the number of ranks is not changed, since otherwise, we will have shuffling discrepancies.
- There is no global shuffling right now, so something like ImageNet training will be flawed.
- Remove logging calls from the codebase.
- Clean broken samples from disk.
- Time-based garbage collection.
- Get/load state dict and make sure we resume from it.
- Can we construct a remote S3 index in parallel?
- Construct an index for a local/remote directory.
- Sometimes, we can have fewer raw index files than nodes.
- Missing fields should be populated in the dataloader or index meta or where? (I guess they should automatically be filled with `None` in the index.)
- Re-slice indices based on sample counts and number of nodes.
- VAE latents loading.
- An option for interleaved indexing.
- Re-opening iter for multi-stream dataloader would break the synchronization of stream types.
- Lazy index does not work with sample_in_epoch.
- We shouldn't need to reset the downloader after each iter_slice finish...
- For lazy index, schedule next index chunk before the current one is finished.
- Make MultiStreamDataLoader robust to re-opening the iterator.
- Docker image.
- Mixing between streams across ranks.
- BQ script with export into a single parquet file.
- Video latents loading.
- Fixed random order.
- Consecutive interleaved order.
- Put on our corp pypi.
- Evict samples inside random access queries as well.
- Our caching logic is broken: we think we've downloaded a sample and occupied some disk space, but it was already there. This makes us delete samples thinking that we need to free up space.
- Support shuffle_seed = -1.
- Audio normalization.
- Row group size = 20k for the new script.
- Tutorial/usage examples
- Video + .wav files loading (now we only support video files with embedded audio).
- Pre-download transforms.
- Documentation
- Tensor parallel support: iterating the streams from one dataloader for one meta-iter and broadcasting them within the group.
- For non-lazy parquet index without slicing and filtering, we don't need to reload-resave it.
- Fix the current unit tests.
- `pre_download_transforms` to first select a caption embedding and then download only the selected one for traffic optimization.
- How can we reweight the index during training? A straightforward way would be randomly filtering out samples in the index via SQL queries. But maybe we can have a reweighting_fn as an input or a weight column in the index?
- Support spawn start method for dataloader workers.
- An option to cache the downloaded/loaded sample dict? Ideally, through some cache transform, i.e. so we can cache at any selected point in the transform chain. Then, we can store videos unpacked as np4/torch files and load them much faster.
- `sds.utils.data_utils.read_parquet_slice` is not working for a wildcard of parquets.
- Make scripts/construct_index_from_bq.py runnable from the command line. I.e., we can have an `sds` command line tool with subcommands.
- Is it possible to make construct_index_from_bq_small.py work for large tables? Its logic is much cleaner...
- Fix random seeds in transforms. Possibly by adding a `__random_seed__` field? Or would fixing a global random seed be enough?
- The logic for resetting the downloader after each epoch is hacky. I don't think we should do that.
- More test coverage: state dict, resumal, index construction, deadlocks, sync-ed dataloader, etc.
- Shutdown for num_workers > 0 is quite slow. Not sure why.
- Recompute sample_in_epoch based on the number of workers. I.e. sample_in_local_epoch => sample_in_global_epoch.
- @beartype for streaming dataset init method.
- Allow empty columns_to_download (i.e., only metadata).
- Refresh index cache when restarting the dataloader? I.e. at least check the new size...
- Support for data providers as callbacks (possibly via forward/backward translation)
- Cache index: save to "cache" on S3 and load from cache (if present). Basically, if we are given a folder or split_file_paths.txt or *.csv, then we could save the index normally (though we should be careful about single-node vs multi-node cases).
- Deterministic order for the thread pool downloader.
- Some race conditions might happen, when someone is evicting/downloading a sample with a downloader, while someone else is doing this via random access, since random access breaks the non-overlapping assumption. Also, we don't free the disk space used by random access samples. We should probably lock the downloader (among all the node workers?!) during random access queries.
- How to support multiple instances of the same dataset in a single process? That might lead to race conditions in downloading/eviction.
- We likely also need some node-level file lock to keep disk usage information for caching, since each new iterator instance is thinking that it's starting from scratch.
- Plot the job cost in terms of downloading, given the number of requests and bytes downloaded.
- We can download video chunks from S3 given the random offset/num frames we need.
- Fix TODOs in the codebase (i.e. grep for "TODO" and fix).
- SQLite index instead of parquet.
- Move synchronous batch-wise yielding to the `StreamingDataset` class using the round-robin assumption of the torch dataloader iterating over workers.
Run the tests with:

PYTHONPATH=. pytest tests

Create your own branch, make changes, and create a pull request.
Style guide:
- Use rebase instead of merges where possible.
- 4 spaces for indentation
- Always annotate shapes of tensors via inline comments (even in the obvious cases), e.g.:
  x = torch.randn(3, 4, 5) # [batch_size, sequence_length, hidden_size]
- Always annotate the type of the function arguments and return values.
- We use something similar to the AngularJS commit styleguide: a commit should be of the form `<type>(<scope>): <subject>`, where `<type>` is one of the following:
  - feat: a new feature
  - fix: a bug fix
  - docs: changes to documentation
  - style: formatting, missing semicolons, etc; no code change
  - refactor: refactoring production code
  - test: adding tests, refactoring tests; no production code change
  - chore: updating build tasks, package manager configs, etc; no production code change
  - revert: reverting to a previous commit
  - perf: a code change that improves performance
  - ci: changes to CI configuration files and scripts
  - build: changes that affect the build system or external dependencies
  - temp: temporary commit that won't be included in the final version
  - wip: work in progress
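For example, a hypothetical commit for one of the TODO items above could read: `feat(downloader): schedule the next lazy index chunk before the current one finishes`.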