🐛 Describe the bug
I think there are two problems with multi-epoch training:

1. Training finishes too early when setting e.g. `duration: 2e12T` and 1 epoch < 2e12 tokens: the run below stopped after the first epoch at ~1.7e12 tokens. It currently requires setting `duration: 2ep`, but I think it should also work with `T` (also mentioned here: https://github.com/allenai/OLMo/issues/554):

```
olmo.train:816 INFO [step=817847/1430511]
    train/CrossEntropyLoss=2.341
    train/Perplexity=10.39
    throughput/total_tokens=1,715,149,471,744
    throughput/device/tokens_per_second=18,573
    throughput/device/batches_per_second=0.5668
olmo.train:1172 INFO Training epoch complete
olmo.train:1194 INFO Saving final checkpoint...
train:238 INFO Training complete
```

2. Afaict, when resuming a run that is already past its first epoch, one has to newly set `epoch: num_epochs` in the config to ensure the data comes in a different order: https://github.com/allenai/OLMo/blob/e6430a070e97154c5b3b5fc35a7416026e9353ad/olmo/data/__init__.py#L103. I think we should just load the epoch from the trainer state dict. However, afaict this is currently not happening because the checkpoint is only loaded after the IterableDataset has already been created: the data loader is built at https://github.com/allenai/OLMo/blob/e6430a070e97154c5b3b5fc35a7416026e9353ad/scripts/train.py#L116, the checkpoint with the epoch value is loaded at https://github.com/allenai/OLMo/blob/e6430a070e97154c5b3b5fc35a7416026e9353ad/scripts/train.py#L238, and the data loader remains unchanged. Without knowing this, people will train the 2nd epoch with the same data order as the 1st.
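For the first problem, a minimal sketch of the intended behavior (all names here are illustrative, not the actual OLMo trainer API): the stopping condition is the token budget alone, so training rolls over into a 2nd epoch instead of finishing when the first "Training epoch complete" is reached.

```python
from dataclasses import dataclass, field
from typing import List

# Toy stand-in for the trainer; `global_train_tokens_seen`, `start_epoch`,
# and `fit` are hypothetical names for illustration only.
@dataclass
class ToyTrainer:
    tokens_per_batch: int
    batches_per_epoch: int
    global_train_tokens_seen: int = 0
    epoch: int = 0
    epochs_started: List[int] = field(default_factory=list)

    def start_epoch(self) -> None:
        # In OLMo this is where the IterableDataset would reshuffle with a
        # new epoch seed so the data order differs between epochs.
        self.epochs_started.append(self.epoch)

    def train_step(self) -> None:
        self.global_train_tokens_seen += self.tokens_per_batch

def fit(trainer: ToyTrainer, duration_tokens: int) -> None:
    # Outer loop over epochs: only the token budget terminates training,
    # not the end of an epoch.
    while trainer.global_train_tokens_seen < duration_tokens:
        trainer.start_epoch()
        for _ in range(trainer.batches_per_epoch):
            trainer.train_step()
            if trainer.global_train_tokens_seen >= duration_tokens:
                return
        trainer.epoch += 1

# One epoch yields 10 * 100 = 1000 tokens, but the budget is 1500 tokens,
# so training must continue into a 2nd epoch.
trainer = ToyTrainer(tokens_per_batch=100, batches_per_epoch=10)
fit(trainer, duration_tokens=1500)
```

With a `duration` given in tokens that exceeds one epoch, the sketch starts epochs 0 and 1 and stops exactly at the budget, which is what I would expect `duration: 2e12T` to do.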
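For the second problem, a sketch of the fix I have in mind (again with illustrative names, not the actual OLMo code): after the trainer state dict is restored from the checkpoint, push the restored epoch into the already-created dataset so the shuffle order changes, instead of requiring the user to set `epoch: num_epochs` by hand.

```python
# Hypothetical minimal model of the relevant pieces: the dataset is built
# before the checkpoint is loaded, so its epoch must be updated afterwards.
class ToyIterableDataset:
    def __init__(self, seed: int, epoch: int = 0):
        self.seed = seed
        self.epoch = epoch

    def shuffle_seed(self) -> int:
        # Data order depends on seed + epoch, so bumping the epoch
        # produces a different order on the next pass over the data.
        return self.seed + self.epoch

def restore_from_checkpoint(dataset: ToyIterableDataset, trainer_state: dict) -> None:
    # The proposed fix: take the epoch from the checkpointed trainer state
    # and propagate it to the dataset that was created earlier.
    dataset.epoch = trainer_state["epoch"]

dataset = ToyIterableDataset(seed=42)   # built before the checkpoint load
trainer_state = {"epoch": 1}            # restored from the checkpoint
restore_from_checkpoint(dataset, trainer_state)
```

Without the `restore_from_checkpoint` call, the resumed run would keep `epoch=0` and replay the 1st epoch's data order.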
Versions
latest main