mosaicml / streaming

A Data Streaming Library for Efficient Neural Network Training
https://streaming.docs.mosaicml.com
Apache License 2.0

Why does StreamingDataset.state_dict() have to be told how many samples have been yielded? #763

Closed thayes427 closed 2 months ago

thayes427 commented 2 months ago

My question is, why does StreamingDataset.state_dict() have to be told how many samples have been yielded? https://github.com/mosaicml/streaming/blob/0fc72665e9c20e1c53d036d09592b7335ae8ae9a/streaming/base/dataset.py#L771

Having the StreamingDataLoader keep track of num_samples and pass it to the dataset when requesting the state dict works for most use cases. It gets much more complicated when downstream steps such as rejection sampling or bin packing sit between the dataset and the training loop, because the number of samples yielded by the dataset no longer matches the number of samples the training loop sees.

So, I just wonder if there's any reason why the StreamingDataset itself can't keep track of how many samples it has yielded? It seems that would enable a simpler state saving and loading interface.

Thanks in advance for any help!
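
To make the bookkeeping problem concrete, here is a minimal sketch in plain Python, not code from the streaming library. `CountingIterator` is a hypothetical helper, `range(20)` stands in for the dataloader, and the modulo test stands in for a rejection-sampling step. It assumes the count has to be taken upstream of the filter, since that is the number of dataset samples actually drawn.

```python
from itertools import islice
from typing import Iterable, Iterator


class CountingIterator:
    """Hypothetical helper: counts raw samples pulled from `source`."""

    def __init__(self, source: Iterable[int]) -> None:
        self._source: Iterator[int] = iter(source)
        self.num_samples = 0

    def __iter__(self) -> "CountingIterator":
        return self

    def __next__(self) -> int:
        sample = next(self._source)  # StopIteration propagates to the caller
        self.num_samples += 1
        return sample


# Stand-in for the dataloader; a real loader would go here instead.
raw = CountingIterator(range(20))

# Toy rejection-sampling step: keep only multiples of 3.
kept = (sample for sample in raw if sample % 3 == 0)

# Pretend training stops for a checkpoint after 4 accepted samples.
accepted = list(islice(kept, 4))

# The number of accepted samples and the number of raw samples drawn differ;
# the raw count is the one that describes the dataset's position.
print(len(accepted), raw.num_samples)
```

In this sketch the downstream consumer only sees the accepted samples, so the raw count has to be captured before the filter, which is the extra tracking described above.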

mvpatel2000 commented 2 months ago

Most dataloaders implement some form of sample prefetching. As a result, the dataset can't know how many samples have actually been processed: when a checkpoint is taken, it may already have yielded samples that are still waiting in the prefetch queue and have not reached the training loop. If the dataset counted its own yields, resuming from that count would skip those samples. Unfortunately, this means you must pass the sample count in when you call state_dict.
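
The gap between "yielded" and "processed" can be reproduced with a toy example. This is a sketch in plain Python under simplifying assumptions: `CountingDataset` and `prefetching_loader` are hypothetical names, and the prefetching is a synchronous buffer rather than the worker processes a real dataloader would use.

```python
from collections import deque
from typing import Deque, Iterator


class CountingDataset:
    """Toy dataset that counts how many samples it has yielded."""

    def __init__(self, size: int) -> None:
        self.size = size
        self.num_yielded = 0

    def __iter__(self) -> Iterator[int]:
        for index in range(self.size):
            self.num_yielded += 1
            yield index


def prefetching_loader(dataset: CountingDataset, prefetch: int) -> Iterator[int]:
    """Keep up to `prefetch` samples buffered ahead of the consumer."""
    buffer: Deque[int] = deque()
    source = iter(dataset)
    exhausted = False
    while True:
        # Refill the buffer before handing out the next sample.
        while not exhausted and len(buffer) < prefetch:
            try:
                buffer.append(next(source))
            except StopIteration:
                exhausted = True
        if not buffer:
            return
        yield buffer.popleft()


dataset = CountingDataset(size=100)
loader = prefetching_loader(dataset, prefetch=8)

num_consumed = 0
for sample in loader:
    num_consumed += 1
    if num_consumed == 10:
        break  # pretend a checkpoint is taken here

# The dataset's own counter runs ahead of what the training loop consumed.
print(dataset.num_yielded, num_consumed)
```

At the simulated checkpoint the dataset's counter is ahead of the consumer by the samples still sitting in the buffer, so only the consumer side knows the count that is safe to resume from.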

thayes427 commented 2 months ago

Thanks for the explanation! Makes sense to me.