Open stas00 opened 1 year ago
In "can cause unreliable results" are you perhaps implying that there is no guarantee the DL will not continue from where it left off on saving the last checkpoint but will repeat the same data? Shouldn't PTL save the worker RNG state and correctly restore it on resume? Though with a custom DL there is no way PTL could easily do that.
Yes exactly, that's what the warning is trying to say. I often struggle to explain to users that resuming mid-epoch is highly non-trivial. To many users it seems so obvious that Lightning should "support" this, or they even assume that Lightning already does, without questioning what might happen to the random state. The surprise is then that the results are skewed because the network has seen some data more often than the rest due to the restart.
We spent quite some time (two entire releases) developing a fault-tolerant system, but it never made it out of the experimental state because of several challenges. Capturing the random state in workers was possible, but it was very costly and came with a load of edge cases to handle, plus lots of caveats around IterableDataset. Even with a limited scope, the situation became unmanageable. We ultimately decided to drop the effort of making dataloaders stateful and resumable, and instead only handle the loop state and trainer state. We hoped that eventually DataLoader2 / torchdata would put the necessary building blocks in place to make data pipes serializable, but now that its development has stopped this won't happen. For now, we can say that Lightning guarantees that the trainer and loop state is managed in a fault-tolerant way, but the data is not and is up to the user.
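To make the scope concrete, here is a minimal sketch of the "easy" part only: snapshotting the main-process RNG states via the LightningModule `on_save_checkpoint` / `on_load_checkpoint` hooks. Everything this does NOT cover (per-worker RNG, sampler internals, IterableDataset state) is where the cost and the edge cases came from; the class name is made up.

```python
import random
import numpy as np
import torch
import lightning.pytorch as pl


class RNGStateModule(pl.LightningModule):
    # Sketch only: snapshot the *main-process* RNG states into the checkpoint.
    # DataLoader worker processes, samplers, and IterableDataset internals are
    # intentionally not touched here -- that is the hard part.

    def on_save_checkpoint(self, checkpoint: dict) -> None:
        checkpoint["rng_states"] = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "torch": torch.get_rng_state(),
            "cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        }

    def on_load_checkpoint(self, checkpoint: dict) -> None:
        states = checkpoint.get("rng_states")
        if states is None:
            return
        random.setstate(states["python"])
        np.random.set_state(states["numpy"])
        torch.set_rng_state(states["torch"])
        if states["cuda"] is not None:
            torch.cuda.set_rng_state_all(states["cuda"])
```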
Back to the warning itself:
You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable results if further training is done. Consider using an end-of-epoch checkpoint
I was thinking that we already have an info message when resuming a checkpoint:
Restored all states from the checkpoint at checkpoints/epoch=0-step=500.ckpt
and we could possibly mention that the checkpoint is a mid-epoch checkpoint in that message, essentially combining the warning into it. For example:
Restored all states from the mid-epoch checkpoint at checkpoints/epoch=0-step=500.ckpt. The DataLoader sampler state will be reset.
(I don't know yet how to best word this in the message, this is just a quick draft)
cc @carmocca
OK, so this is actually not just an advisory warning relevant only to some; this is a problem that users need to be made aware of by all means. In that case I'd raise an exception on resume, explain what happens, and give the user an I-know-what-I-am-doing flag to suppress the exception. The current warning will be ignored by many (most?) users and unexpected problems will follow, since most will blindly believe that PL has it figured out for them. This is too important not to stop the presses, IMHO.
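Something along these lines, purely as an illustration (neither the function nor the flag exists in Lightning today; the names are made up):

```python
def guard_mid_epoch_resume(resumed_mid_epoch: bool, i_know_what_i_am_doing: bool = False) -> None:
    # Hypothetical guard for the behaviour proposed above: stop the presses unless
    # the user explicitly opts in.
    if resumed_mid_epoch and not i_know_what_i_am_doing:
        raise RuntimeError(
            "You are resuming from a mid-epoch checkpoint. The trainer/loop state is "
            "restored, but the DataLoader/sampler state is NOT, so some samples may be "
            "seen twice and others skipped. Resume from an end-of-epoch checkpoint, or "
            "pass i_know_what_i_am_doing=True if you handle data resumption yourself."
        )
```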
Somehow I thought we had it right in the HF Trainer wrt RNG restoration in the DL, but it's been a long time since I worked on it, so my memory is foggy. Perhaps at that time we really only dealt with the generic situations. But I totally hear you that in some cases it's easy, whereas in other cases the user should be taking care of it.
Perhaps:
and, as I suggested in the first paragraph, I think this situation warrants raising an exception.
These are all good suggestions. The warning is really old and should be updated to describe these limitations. AFAIK the main missing pieces would be to implement https://github.com/Lightning-AI/lightning/issues/17105 and to implement loading/reloading of self.log-ged metrics (this was partially implemented before https://github.com/Lightning-AI/lightning/pull/16516).
Does the checkpoint save the number of batches that were seen in the current epoch? Thinking about how to resume from a mid-epoch ckpt, one could just iterate through the batches, skipping them until batch_index > saved_index.
btw fairseq has this capability built in
Yes, that's what HF Trainer does and if I remember correctly Megatron-LM does as well.
But this only works well if you have a simple DataSampler, ideally with already preprocessed data. If you use a complicated one that requires a lot of real-time processing, such fast-forwarding could be extremely slow, so one probably needs to disable any transformations for such an action.
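A rough sketch of such fast-forwarding, assuming the shuffle order on resume is identical to the original run (illustrative names, not an existing API):

```python
from torch.utils.data import DataLoader


def fast_forward(dataloader: DataLoader, batches_seen: int):
    """Skip the batches consumed before the checkpoint, then yield the rest.

    Only valid if the sampler reproduces the same order as the original run,
    and potentially very slow when each skipped batch still triggers expensive
    transforms or I/O.
    """
    it = iter(dataloader)
    for _ in range(batches_seen):
        next(it)  # discard batches the model has already trained on
    yield from it


# usage (illustrative): resume_iter = fast_forward(train_loader, batches_seen=saved_index)
```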
Additionally, if the dataset is remote and webdataset or a similar DL is used, this again isn't quite doable, since you would potentially have to re-download many chunks of data from remote storage.
In these complicated cases, keeping track of RNG states and restoring them is a better solution, albeit remote storage handling can still be a problem.
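For the RNG-state route, a minimal sketch of saving/restoring the shuffle generator of a plain PyTorch DataLoader (illustrative only, not what Lightning does):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(1000))
gen = torch.Generator().manual_seed(1234)
loader = DataLoader(dataset, batch_size=8, shuffle=True, generator=gen)

# At the start of the epoch, before iterating: stash the generator state
# (a small ByteTensor) alongside the checkpoint.
shuffle_rng_state = gen.get_state()

# On resume: restore it so the sampler regenerates exactly the same permutation,
# then combine this with skipping the already-seen batches.
gen.set_state(shuffle_rng_state)

# Caveat: this only covers the shuffle order. Per-worker augmentation RNG
# (random transforms inside Dataset.__getitem__) is still not captured.
```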
In RETURNN, we also have this capability. More specifically, we operate on sub-epochs. The user specifies the number of random partitions of a dataset. E.g. for Librispeech, we use 20, so each sub-epoch covers around 100h of audio. Once a full epoch is finished, the partitioning is redone.
In our case, we shuffle the sequences in advance for the whole dataset for each whole epoch, and then partition it evenly into the sub-epochs. This approach might not scale well for very large corpora though, as you need to operate on the list of sequences after every epoch, which might be too large to handle. (For all our research corpora, it was not a problem so far.)
All our dataset logic, also including any on-the-fly processing (e.g. augmentation etc) use a shared RNG, which is seeded at the beginning of every sub-epoch. This assures that we can safely recover after a sub-epoch.
Shuffling the sequences can also be done on-the-fly, so I think this approach can still scale to even much larger corpora.
Maybe such an approach could be interesting here as well.
But if you can properly serialize the RNG state of any data loader iterator and any other data sampler in between, then you can also recover the state after every sequence or mini batch. The approach in RETURNN does not need to serialize the RNG state, though, so it's a bit simpler to implement.
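A tiny sketch of that sub-epoch scheme (illustrative only, not RETURNN's actual implementation): the shuffle depends only on (seed, epoch), so after a crash any sub-epoch's partition and its processing RNG can be recomputed without serializing any RNG state.

```python
import numpy as np

NUM_SUB_EPOCHS = 20  # e.g. roughly 100h of Librispeech audio per sub-epoch


def sub_epoch_indices(num_sequences: int, epoch: int, sub_epoch: int, seed: int = 0) -> np.ndarray:
    # Shuffle the whole epoch deterministically, then take this sub-epoch's slice.
    rng = np.random.default_rng(seed + epoch)
    order = rng.permutation(num_sequences)
    return np.array_split(order, NUM_SUB_EPOCHS)[sub_epoch]


def sub_epoch_rng(epoch: int, sub_epoch: int, seed: int = 0) -> np.random.Generator:
    # Shared RNG for on-the-fly processing (augmentation etc.), reseeded per
    # sub-epoch, so training can safely resume at any sub-epoch boundary.
    return np.random.default_rng([seed, epoch, sub_epoch])
```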
Description & Motivation
Forking from https://github.com/Lightning-AI/lightning/issues/18723#issuecomment-1751307472 where we were discussing various warnings that don't necessarily apply to everyone.
This issue discusses this warning:
You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable results if further training is done. Consider using an end-of-epoch checkpoint
Many shared SLURM environments have a relatively short time limit on each job, so one can't do one epoch w/o a restart and resume; e.g. some allow only 20h tops.
In "can cause unreliable results" are you perhaps implying that there is no guarantee the DL will not continue from where it left off on saving the last checkpoint but will repeat the same data? Shouldn't PTL save the worker RNG state and correctly restore it on resume? Though with a custom DL there is no way PTL could easily do that.
But in general a 3-month training will take many restarts, not only because of a short SLURM job limit, but also because there will be divergences requiring rollbacks, which means restarts.
And yes, the operator needs to be super-aware of whether the resume breaks the unique flow of samples and effectively leads to sampling with replacement (some samples seen more than once).
cc @borda @justusschock @awaelchli