EleutherAI / pythia

The hub for EleutherAI's work on interpretability and learning dynamics
Apache License 2.0

[Pythia on Pile-Dedup] Training for ~1.5 epochs: how to identify the repeated sequences (i.e., the additional .5 epoch)? #144

Open pietrolesci opened 6 months ago

pietrolesci commented 6 months ago

Hi there,

The deduplicated dataset has fewer sequences, so to keep the token count consistent with the non-deduplicated version, the models are trained for ~1.5 epochs (as discussed in the README). Between epochs, are the data reshuffled, or does the dataloader simply start again from the beginning in the same order? If the latter, is there a way to know exactly which checkpoint is the first to see the same data twice? Put differently, is there a way to know which sequences the model sees in the additional ~half epoch?
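For context, I'm loading the intermediate checkpoints from the Hugging Face Hub via the `revision` argument, along these lines (the model size and step here are arbitrary examples, not the checkpoint in question):

```python
from transformers import GPTNeoXForCausalLM, AutoTokenizer

# Each saved training step is published as a branch on the model repo;
# "step3000" and the 70m-deduped model are arbitrary choices here.
model = GPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m-deduped",
    revision="step3000",
)
tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/pythia-70m-deduped",
    revision="step3000",
)
```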

Thanks a lot in advance for your help!

cc @haileyschoelkopf

jeffreygwang commented 6 months ago

Hey! I had similar questions a while back for a paper in which we used the Pythia suite. To the best of my understanding, it's ~1.5 epochs, where roughly the first half of the data is seen twice, in the same order. The Pythia paper reports how many tokens the models see in total and how many they see in the first pass; based on those numbers, I use the step98000 checkpoint as my full "single pass" checkpoint. I believe the checkpoints after that start "seeing double."
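As a back-of-the-envelope check of where that boundary falls, here is the arithmetic, assuming the figures reported in the Pythia paper (143,000 total steps, a global batch of 1024 sequences of 2048 tokens, and ~207B tokens in the deduplicated Pile):

```python
# Rough epoch-boundary calculation; all constants are the (approximate)
# figures from the Pythia paper, not values read out of the training code.
tokens_per_step = 1024 * 2048       # ~2.10M tokens per optimizer step
dedup_pile_tokens = 207e9           # approximate size of Pile-Dedup

one_epoch_steps = dedup_pile_tokens / tokens_per_step
print(f"one epoch ~= step {one_epoch_steps:,.0f}")             # ~98,700

total_steps = 143_000               # full Pythia training run
print(f"total epochs ~= {total_steps / one_epoch_steps:.2f}")  # ~1.45
```

This lands just above step 98,000, which is why step98000 is a reasonable "last checkpoint before any repeats" choice.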

pietrolesci commented 6 months ago

Thanks a lot for your answer @jeffreygwang, this seems reasonable to me too!

pietrolesci commented 6 months ago

> Between epochs, are the data reshuffled, or does the dataloader simply start again from the beginning in the same order?

The answer seems to be that the dataloader does NOT simply replay the same packed sequences. The concatenation happened at the document level, i.e., before the "packing" step that chunks the token stream into fixed-length training sequences. As a result, a document's tokens can appear at different positions within a sequence in the second pass than they did in the first.
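A toy illustration of this effect (this is not the actual GPT-NeoX dataloader, just a sketch of document-level concatenation followed by packing):

```python
# Documents are concatenated at the document level into one long stream,
# then the stream is packed into fixed-length training sequences.
SEQ_LEN = 8  # Pythia uses 2049; a tiny value keeps the example readable

docs = [list(range(10, 15)), list(range(20, 27)), list(range(30, 33))]

# Concatenate the corpus twice at the document level (the real run stops
# partway through the second pass, hence ~1.5 epochs).
stream = [tok for _ in range(2) for doc in docs for tok in doc]

# Pack the stream into fixed-length sequences.
sequences = [stream[i:i + SEQ_LEN] for i in range(0, len(stream), SEQ_LEN)]
for seq in sequences:
    print(seq)

# The corpus is 15 tokens long, which is not a multiple of SEQ_LEN, so the
# second copy of each document starts at a different offset within its
# packed sequence than the first copy did.
```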