thayes427 closed this issue 2 months ago.
Most dataloaders prefetch samples. As a result, Streaming can't track the number of samples processed on its own: when a checkpoint is taken, the dataset may already have yielded samples that have not actually been processed yet. Unfortunately, this means you must pass in the sample count when calling `state_dict`.
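To illustrate the prefetching problem, here is a minimal, self-contained sketch. `CountingDataset` and `PrefetchingLoader` are hypothetical stand-ins, not Streaming's classes. The loader buffers a fixed number of samples ahead of the training loop, the way a prefetch queue does, so at "checkpoint" time the dataset's own count runs ahead of what was actually consumed.

```python
from collections import deque
from typing import Iterator, List


class CountingDataset:
    """Hypothetical dataset that counts how many samples it has yielded."""

    def __init__(self, samples: List[int]) -> None:
        self.samples = samples
        self.num_yielded = 0  # what the dataset itself believes has been used

    def __iter__(self) -> Iterator[int]:
        for sample in self.samples:
            self.num_yielded += 1
            yield sample


class PrefetchingLoader:
    """Hypothetical loader that keeps `prefetch` samples buffered ahead of the consumer."""

    def __init__(self, dataset: CountingDataset, prefetch: int) -> None:
        self.dataset = dataset
        self.prefetch = prefetch
        self.num_consumed = 0  # samples actually handed to the training loop

    def __iter__(self) -> Iterator[int]:
        buffer = deque()
        for sample in self.dataset:
            buffer.append(sample)
            # Release a sample only once the buffer holds more than `prefetch` items.
            if len(buffer) > self.prefetch:
                self.num_consumed += 1
                yield buffer.popleft()
        # Drain whatever is still buffered at the end of the epoch.
        while buffer:
            self.num_consumed += 1
            yield buffer.popleft()


dataset = CountingDataset(list(range(100)))
loader = PrefetchingLoader(dataset, prefetch=4)
it = iter(loader)
consumed = [next(it) for _ in range(10)]  # training loop processes 10 samples

# "Checkpoint" now: the dataset has yielded 4 more samples than were processed.
print(dataset.num_yielded, loader.num_consumed)  # 14 10
```

Resuming from the dataset's own count (14) would silently skip samples 10 to 13, which is why the count has to come from the consumer side.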
Thanks for the explanation! Makes sense to me.
My question is: why does `StreamingDataset.state_dict()` need to be told how many samples have been yielded? https://github.com/mosaicml/streaming/blob/0fc72665e9c20e1c53d036d09592b7335ae8ae9a/streaming/base/dataset.py#L771
The pattern of `StreamingDataLoader` tracking `num_samples` and passing it to the dataset to get the state dict works for most use cases. However, things get much more complicated when downstream steps such as rejection sampling and bin packing make it harder to tell how many samples were yielded from the dataset. So I wonder whether there's a reason `StreamingDataset` itself can't keep track of how many samples it has yielded. That would seem to enable a simpler interface for saving and loading state.
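One way to cope with rejection sampling and bin packing, sketched under assumptions (every function here is hypothetical and not part of Streaming): tag each raw sample with a running source count and carry that tag through every downstream stage. Each emitted item then records how many dataset samples it has used up, and the training loop can take the value to pass to `state_dict` from the last item it actually consumed. Because the tag travels with the data, a prefetch queue further downstream doesn't throw it off.

```python
from typing import Iterable, Iterator, List, Tuple

# Each stage yields (item, n_source): n_source is the number of samples pulled
# from the dataset up to and including the last one that contributed to item.
Tagged = Tuple[List[int], int]


def tag_source(samples: Iterable[List[int]]) -> Iterator[Tagged]:
    """Attach a running source-sample count to each raw sample."""
    for i, sample in enumerate(samples, start=1):
        yield sample, i


def reject_long(stream: Iterable[Tagged], max_len: int) -> Iterator[Tagged]:
    """Hypothetical rejection step: drop samples longer than max_len."""
    for sample, n_source in stream:
        if len(sample) <= max_len:
            yield sample, n_source


def greedy_pack(stream: Iterable[Tagged], bin_size: int) -> Iterator[Tagged]:
    """Hypothetical greedy packer: emit a bin when the next sample would overflow it."""
    current: List[int] = []
    last_n = 0
    for sample, n_source in stream:
        if current and len(current) + len(sample) > bin_size:
            yield current, last_n
            current = []
        current.extend(sample)
        last_n = n_source  # count of the last sample placed in the open bin
    if current:
        yield current, last_n


raw = [[1, 2], [3, 4, 5, 6, 7, 8], [9], [10, 11], [12, 13, 14]]
pipeline = greedy_pack(reject_long(tag_source(raw), max_len=4), bin_size=4)
results = list(pipeline)
for pack, n_source in results:
    print(pack, n_source)
# [1, 2, 9] 3
# [10, 11] 4
# [12, 13, 14] 5
```

After the training loop finishes the first pack, the count to checkpoint is 3, and resuming at that offset starts again at `[10, 11]`. A rejected sample that falls between two packs isn't counted by the earlier pack, so it is re-read and rejected again on resume; this assumes rejection is deterministic. The sketch also ignores multiple workers and ranks, whose counts would have to be combined.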
Thanks in advance for any help!