Lightning-AI / litdata

Transform datasets at scale. Optimize datasets for fast AI model training.
Apache License 2.0

Use different batch sizes in CombinedStreamingDataset #327

Open: schopra8 opened this issue 1 month ago

schopra8 commented 1 month ago

🚀 Feature

CombinedStreamingDataset lets you combine multiple StreamingDatasets with a sampling ratio, but it assumes that the batch_size is the same for every dataset.

Motivation

If the datasets contain tensors of different sizes, it would be useful to set a different batch size per dataset to maximize throughput while staying within memory limits (e.g. a batch size of 1 for the dataset with larger input tensors and a batch size of 2 for the dataset with smaller input tensors).

Pitch

Allow set_batch_size to accept a list of batch sizes, one per dataset.
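To make the pitch concrete, here is a minimal, library-agnostic sketch of "one batch size per dataset" combined with a sampling ratio. It assumes each dataset is indexed 0..n-1 and read sequentially; `mixed_batch_indices` is a hypothetical helper, not part of litdata, and it does not use any litdata API.

```python
import random
from typing import Iterator, List, Sequence, Tuple


def mixed_batch_indices(
    dataset_sizes: Sequence[int],
    batch_sizes: Sequence[int],
    weights: Sequence[float],
    seed: int = 0,
) -> Iterator[Tuple[int, List[int]]]:
    """Yield (dataset_index, sample_indices) pairs (hypothetical helper).

    Each step picks one dataset according to `weights` (the sampling ratio)
    and emits a batch of that dataset's own batch size. Iteration stops as
    soon as the chosen dataset cannot fill a full batch (a simple
    "drop_last"-style policy chosen for this sketch).
    """
    if not (len(dataset_sizes) == len(batch_sizes) == len(weights)):
        raise ValueError("expected one batch size and one weight per dataset")
    rng = random.Random(seed)
    cursors = [0] * len(dataset_sizes)  # next unread sample index per dataset
    while True:
        # Choose which dataset the next batch comes from.
        i = rng.choices(range(len(dataset_sizes)), weights=weights, k=1)[0]
        start, end = cursors[i], cursors[i] + batch_sizes[i]
        if end > dataset_sizes[i]:
            return  # chosen dataset is exhausted; stop iterating
        cursors[i] = end
        yield i, list(range(start, end))


if __name__ == "__main__":
    # Dataset A: 10 large samples, batch size 1; dataset B: 40 small samples, batch size 4.
    for ds, idx in mixed_batch_indices([10, 40], [1, 4], [0.5, 0.5]):
        print(ds, idx)
```

Stopping when any chosen dataset runs out is only one possible policy; cycling or skipping exhausted datasets would work equally well in this sketch.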

Alternatives

One thing that would need to be considered is gradient accumulation. For example, suppose dataset A has large tensors, so only 1 sample fits in memory per batch, while dataset B has small tensors, so 4 samples fit in memory per batch. If you want a 50-50 split between dataset A and dataset B during training, you would run 4 steps of gradient accumulation on batches from dataset A, so that each optimizer step sees 4 samples whichever dataset they come from. If you want a different ratio of samples from dataset A vs. dataset B, the number of gradient accumulation steps would need to be configurable.
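The arithmetic behind the accumulation argument can be sketched as follows. The function names are hypothetical, and the sketch assumes each optimizer step draws all of its samples from a single dataset. `accumulation_steps` computes how many micro-batches to accumulate so every optimizer step sees the same number of samples. `sample_share` shows why this matters: if a dataset is picked per micro-batch with a 50-50 ratio but the batch sizes are 1 and 4, dataset B ends up supplying 80% of the samples.

```python
from typing import Dict, List, Sequence


def accumulation_steps(batch_sizes: Dict[str, int], effective_batch_size: int) -> Dict[str, int]:
    """Micro-batches to accumulate per optimizer step, per dataset (hypothetical helper).

    Assumes every optimizer step should see `effective_batch_size` samples,
    all drawn from the same dataset.
    """
    steps = {}
    for name, b in batch_sizes.items():
        if effective_batch_size % b != 0:
            raise ValueError(f"{name}: batch size {b} does not divide {effective_batch_size}")
        steps[name] = effective_batch_size // b
    return steps


def sample_share(weights: Sequence[float], batch_sizes: Sequence[int]) -> List[float]:
    """Expected fraction of samples per dataset when a dataset is chosen per
    micro-batch (no accumulation) with probability proportional to `weights`."""
    total = sum(w * b for w, b in zip(weights, batch_sizes))
    return [w * b / total for w, b in zip(weights, batch_sizes)]


print(accumulation_steps({"A": 1, "B": 4}, 4))  # {'A': 4, 'B': 1}
print(sample_share([0.5, 0.5], [1, 4]))  # roughly [0.2, 0.8]
```

With accumulation in place, each optimizer step carries the same number of samples, so the sample ratio between A and B equals the ratio at which optimizer steps are assigned to each dataset; changing that ratio is what makes the accumulation schedule need to be configurable.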

Additional context

tchaton commented 1 month ago

Hey @schopra8. Feel free to make a contribution. The main challenge will be to ensure fault tolerance works properly.
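One reason fault tolerance gets harder: with per-dataset batch sizes, a single global batch counter no longer pins down where each dataset should resume. A minimal sketch of the state that would need checkpointing follows: the RNG state that decides which dataset supplies the next batch, plus a per-dataset offset counted in samples. `ResumableDatasetPicker` is hypothetical; it only borrows the common PyTorch-style `state_dict` / `load_state_dict` naming and is not litdata's actual implementation.

```python
import random
from typing import Any, Dict, Sequence


class ResumableDatasetPicker:
    """Hypothetical helper: decides which dataset the next batch comes from
    and can checkpoint/restore that decision stream."""

    def __init__(self, weights: Sequence[float], batch_sizes: Sequence[int], seed: int = 0):
        self.weights = list(weights)
        self.batch_sizes = list(batch_sizes)
        self.rng = random.Random(seed)
        # Offsets are tracked in samples, not batches, because batch sizes differ.
        self.samples_seen = [0] * len(self.weights)

    def next_batch(self) -> Dict[str, int]:
        i = self.rng.choices(range(len(self.weights)), weights=self.weights, k=1)[0]
        start = self.samples_seen[i]
        self.samples_seen[i] += self.batch_sizes[i]
        return {"dataset": i, "start": start, "size": self.batch_sizes[i]}

    def state_dict(self) -> Dict[str, Any]:
        # Copy the offsets so later calls do not mutate the checkpoint.
        return {"rng": self.rng.getstate(), "samples_seen": list(self.samples_seen)}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.rng.setstate(state["rng"])
        self.samples_seen = list(state["samples_seen"])


if __name__ == "__main__":
    picker = ResumableDatasetPicker([0.5, 0.5], batch_sizes=[1, 4], seed=123)
    for _ in range(5):
        picker.next_batch()
    ckpt = picker.state_dict()
    expected = [picker.next_batch() for _ in range(3)]

    resumed = ResumableDatasetPicker([0.5, 0.5], batch_sizes=[1, 4], seed=999)
    resumed.load_state_dict(ckpt)
    print([resumed.next_batch() for _ in range(3)] == expected)  # True
```

A real implementation would also have to handle multiple dataloader workers and distributed ranks, which this sketch ignores.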