mosaicml / streaming

A Data Streaming Library for Efficient Neural Network Training
https://streaming.docs.mosaicml.com
Apache License 2.0

Degraded shuffle quality near the end of an epoch #796

Closed thayes427 closed 3 weeks ago

thayes427 commented 1 month ago

Hi! My team has been using streaming datasets for many different experiments, and we consistently observe that shuffle quality can dramatically degrade near the end of an epoch.

Here are the settings we're using:

Do you have any recommendation for different settings that can improve the shuffle quality? Thank you!

karan6181 commented 3 weeks ago

@thayes427 I would recommend leaving predownload untouched and setting shuffle_algo to py1e, which is the default. For the other parameters, you can try the streaming simulator.

@snarayan21 Anything you see obvious here?

snarayan21 commented 3 weeks ago

Hey @thayes427, thanks for filing this and apologies for not getting to it earlier.

Couple of questions:

Our strongest shuffling algorithms are py1e, py1br, and py1b. The shuffle strength of these algorithms is determined by num_canonical_nodes*shuffle_block_size. See the shuffling page on our docs for some more info.
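To make the product above concrete, here is a minimal sketch that computes num_canonical_nodes*shuffle_block_size and compares it to an epoch size. The helper names and the example values are hypothetical, and treating the product as a rough count of samples that get mixed together is an assumption for illustration, not a statement of how the library shuffles internally.

```python
def shuffle_strength(num_canonical_nodes: int, shuffle_block_size: int) -> int:
    """Product the thread names as determining shuffle strength."""
    return num_canonical_nodes * shuffle_block_size


def strength_fraction(num_canonical_nodes: int,
                      shuffle_block_size: int,
                      epoch_size: int) -> float:
    """Shuffle strength as a fraction of the epoch, capped at 1.0.

    Assumption: a larger fraction suggests a more thorough mix of the epoch.
    """
    strength = shuffle_strength(num_canonical_nodes, shuffle_block_size)
    return min(1.0, strength / epoch_size)


if __name__ == "__main__":
    # Hypothetical values, chosen only to show the arithmetic.
    nodes, block, epoch = 8, 262_144, 100_000_000
    print(shuffle_strength(nodes, block))
    print(strength_fraction(nodes, block, epoch))
```

Under this reading, raising either factor raises the product, so both are candidates to sweep in the streaming simulator.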

I would also recommend using the streaming simulator tool to help analyze different options for shuffle quality.

Hope this helps!

thayes427 commented 3 weeks ago

Thank you both for your helpful replies!

How are you measuring shuffle quality degradation?

We log the dataset source of each sample in our training loop, and we see that the percentage of samples from each source fluctuates much more near the end of the training epoch. We also see this reflected in the training loss, since samples from different dataset sources tend to be easier or harder on average.
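For anyone who wants to reproduce this kind of measurement, here is a minimal sketch of a rolling per-source fraction tracker. The class name, the window size, and the idea of passing a source label per sample are assumptions for illustration. This is not our actual logging code.

```python
from collections import Counter, deque
from typing import Dict


class SourceMixTracker:
    """Tracks the fraction of recent samples that came from each source."""

    def __init__(self, window: int) -> None:
        self.window = deque(maxlen=window)  # most recent source labels
        self.counts = Counter()             # label -> count inside the window

    def update(self, source: str) -> None:
        # When the window is full, the oldest label is about to be evicted,
        # so remove it from the counts before appending the new one.
        if len(self.window) == self.window.maxlen:
            self.counts[self.window[0]] -= 1
        self.window.append(source)
        self.counts[source] += 1

    def fractions(self) -> Dict[str, float]:
        total = len(self.window)
        return {s: c / total for s, c in self.counts.items() if c > 0}


if __name__ == "__main__":
    tracker = SourceMixTracker(window=4)
    for label in ["a", "b", "a", "a", "b", "b"]:
        tracker.update(label)
    print(tracker.fractions())
```

Plotting these fractions against the training step makes it easy to see whether the mix drifts as the epoch ends.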

What does your dataset look like? Are you mixing multiple data sources?

Yes, we are mixing several data sources, typically 6 to 10 sources. These datasets mostly consist of what is essentially text data with samples varying in length between datasets.

sampling granularity

I honestly don't remember why we set sampling_granularity to 10; I need to check with my colleague on that. Removing this setting makes sense to me.

Thank you for your recommendations! We'll try py1e and check out the streaming simulator tool.

snarayan21 commented 3 weeks ago

@thayes427 Yeah, I suspected you might have multiple data sources. I highly suggest you look into specifying the batching_method, especially the stratified one, since this will take the same number of samples from each stream in every single batch, deterministically. More info on dataset mixing & batching methods here.
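To illustrate the idea, here is a self-contained sketch of stratified batching in plain Python: every batch draws a fixed quota from each stream, so the per-source mix cannot drift between batches. The function name, the per-stream quotas, and the toy data are hypothetical. This is a conceptual sketch, not the library's implementation of the stratified batching method.

```python
import random
from typing import Dict, Iterator, List, Tuple


def stratified_batches(streams: Dict[str, List[int]],
                       per_stream: Dict[str, int],
                       seed: int = 0) -> Iterator[List[Tuple[str, int]]]:
    """Yield batches holding exactly per_stream[name] samples from each stream."""
    rng = random.Random(seed)
    # Shuffle each stream independently; the seed makes the order repeatable.
    shuffled = {name: rng.sample(items, len(items))
                for name, items in streams.items()}
    # Stop once any stream can no longer fill its quota.
    num_batches = min(len(shuffled[name]) // per_stream[name]
                      for name in streams)
    for b in range(num_batches):
        batch = []
        for name in streams:
            k = per_stream[name]
            batch.extend((name, s) for s in shuffled[name][b * k:(b + 1) * k])
        rng.shuffle(batch)  # mix the streams within the batch
        yield batch


if __name__ == "__main__":
    toy = {"web": list(range(8)), "code": list(range(4))}
    for batch in stratified_batches(toy, {"web": 2, "code": 1}):
        print(batch)
```

In this sketch each batch contains two "web" samples and one "code" sample regardless of position in the epoch, which is the property that should keep the source percentages flat.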