mosaicml / streaming

A Data Streaming Library for Efficient Neural Network Training
https://streaming.docs.mosaicml.com
Apache License 2.0
1.1k stars 137 forks

Make `epoch_sample_ids` cachable #792

Open janEbert opened 2 weeks ago

janEbert commented 2 weeks ago

🚀 Feature Request

It would be awesome to enable caching of epoch_sample_ids.

Motivation

Caching would remove a lot of redundant work that is currently re-executed on every run. Creating my dataset's sample IDs takes 20 minutes, which wastes a lot of compute budget for large-scale runs.

In my case, I'll focus on the implementation in https://github.com/mosaicml/streaming/blob/2e9db78db6dd4108b697cfde92a95cd0de80539c/streaming/base/batching/random.py. Specifically, what takes long is dataset.resample_streams (with sampling_method="balanced") and get_shuffle (with shuffle_algo="py1e" in my case).

[Optional] Implementation

I've looked into this a bit, but get_shuffle's indirect dependence (through get_partitions) on sample_in_epoch (drop_first in the called functions) seems to make this very difficult. Maybe someone with more knowledge of the codebase can chime in on this, though. I would personally be happy with a simple, hacky solution for now. :)

For now I've implemented a naive hash-keyed NumPy file cache for dataset.resample_streams, which already saves around 40–50% of the time.
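For reference, a minimal sketch of what such a hash-keyed cache could look like. The function name, the keying scheme, and the assumption that the result is a single NumPy array are all hypothetical; this is not Streaming's API, just an illustration of the approach.

```python
import hashlib
import os

import numpy as np


def cached_sample_ids(cache_dir, key_parts, compute_fn):
    """Load a cached sample-ID array if one exists for these inputs, else compute and save it.

    ``key_parts`` is whatever determines the result (seed, epoch, stream
    weights, ...); ``compute_fn`` is a zero-argument callable doing the
    actual (expensive) work, e.g. wrapping dataset.resample_streams.
    """
    # Derive a stable cache key from the inputs.
    key = hashlib.sha256(repr(key_parts).encode()).hexdigest()[:16]
    path = os.path.join(cache_dir, f'sample_ids_{key}.npy')
    if os.path.exists(path):
        return np.load(path)
    ids = compute_fn()
    os.makedirs(cache_dir, exist_ok=True)
    np.save(path, ids)
    return ids
```

Note that `repr`-based keying only works if every element of `key_parts` has a stable, deterministic `repr`; hashing a canonical serialization would be more robust.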

snarayan21 commented 3 days ago

Hey @janEbert, this seems sensible! We have chosen not to cache the epoch sample ID tensor mainly because persistent storage may not be available in many training setups, so reading from a cached file is not always possible.

However, this could be an optional feature for users who do have persistent storage set up. Dumping the numpy tensor to a file is honestly a good start -- we'd be happy to help review an implementation, and we always appreciate community PRs!
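To make the idea of an opt-in feature concrete, here is one possible shape for it: a hypothetical wrapper (none of these names exist in Streaming) that only touches disk when the user supplies a cache path, and otherwise falls back to recomputing, which preserves today's behavior for setups without persistent storage.

```python
import os

import numpy as np


def get_epoch_sample_ids(epoch, compute_fn, cache_path=None):
    """Return the per-epoch sample-ID tensor, optionally cached on disk.

    ``compute_fn(epoch)`` does the expensive generation. When ``cache_path``
    is None (e.g. no persistent storage available), we always recompute.
    """
    if cache_path is not None:
        path = f'{cache_path}.epoch{epoch}.npy'
        if os.path.exists(path):
            return np.load(path)
    ids = compute_fn(epoch)
    if cache_path is not None:
        np.save(path, ids)
    return ids
```

Keying the file name by epoch matters because the sample ordering differs per epoch; anything else that affects the result (seed, world size, shuffle algorithm) would need to be part of the key or validated on load.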

janEbert commented 2 days ago

I see, that makes sense. It also looks like the indices are recalculated on every validation run, so there is really only a time savings when you start a run or load from a checkpoint.

Regarding the implementation, I'll be happy to put what I cooked up into a PR once I find some free time. Given the recalculation I mentioned above (if I interpreted my logs correctly), though, maybe the additional complexity isn't really worth adding to the code base. :)