pimdh opened this issue 6 months ago
@pimdh Thank you for investigating this already.
Since the training loop is quite complex, I can't say for sure that this is the right solution, but it sounds reasonable. Would you be interested in sending a PR with this change? We can then run the full test suite on your PR and see if there are any edge cases. If it works, I can help add a test case.
Hi @awaelchli, I've filed the PR at #19583. While this suffices for my use case, I unfortunately won't have time to add unit tests to validate it. Thanks
The same bug also happens when resuming training from a checkpoint after training on an IterableDataset.
Bug description
When resuming from a mid-epoch checkpoint (which I have to use because my dataset is large), the training loop runs the validation loop for only one iteration, which leads to a wrong validation loss being logged.
It appears that the `batch_progress` of `lightning.pytorch.loops._EvaluationLoop` wrongly gets filled from the checkpoint as if the validation loop had already completed, and is not properly reset after the checkpoint is loaded.

What version are you seeing the problem on?
v2.2
How to reproduce the bug
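No reproduction script was attached to the report; the following is a minimal sketch of the setup described above, assuming a toy map-style dataset and a mid-epoch `ModelCheckpoint`. The `RandomDataset`/`BoringModel` names and all hyperparameters are illustrative, not from the issue.

```python
# Illustrative sketch (not from the original report): save a checkpoint in
# the middle of an epoch, then resume from it and observe the first
# validation run after resuming finishing after a single batch.
import torch
from torch.utils.data import DataLoader, Dataset
import lightning.pytorch as pl


class RandomDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).mean()

    def validation_step(self, batch, batch_idx):
        # The logged validation loss is wrong after resuming, because the
        # loop reportedly runs for only one batch instead of the full loader.
        self.log("val_loss", self.layer(batch).mean())

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def run():
    train_loader = DataLoader(RandomDataset(), batch_size=4)  # 16 batches/epoch
    val_loader = DataLoader(RandomDataset(), batch_size=4)

    # Save a checkpoint mid-epoch (step 8 of 16).
    ckpt = pl.callbacks.ModelCheckpoint(every_n_train_steps=8, save_top_k=-1)
    trainer = pl.Trainer(max_epochs=1, callbacks=[ckpt], val_check_interval=8)
    trainer.fit(BoringModel(), train_loader, val_loader)

    # Resume from the mid-epoch checkpoint; the first validation loop after
    # resuming should run all 16 batches but reportedly stops after one.
    trainer = pl.Trainer(max_epochs=2, val_check_interval=8)
    trainer.fit(BoringModel(), train_loader, val_loader,
                ckpt_path=ckpt.best_model_path)


if __name__ == "__main__":
    run()
```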
Error messages and logs
Environment
Current environment
```
* CUDA:
    - GPU:
        - NVIDIA GeForce RTX 2080 Ti
    - available: True
    - version: 11.7
* Lightning:
    - lightning: 2.2.0.post0
    - lightning-utilities: 0.10.1
    - pytorch-lightning: 2.2.0.post0
    - torch: 2.0.1
    - torch-ema: 0.3
    - torch-geometric: 2.5.0
    - torch-scatter: 2.1.2+pt20cu117
    - torchmetrics: 1.3.1
    - torchvision: 0.15.2
* System:
    - OS: Linux
    - architecture:
        - 64bit
    - processor: x86_64
    - python: 3.10.13
    - release: 5.4.0-152-generic
```

More info
A fix/workaround for this issue is to add `self.batch_progress.reset_on_run()` at the end of `_EvaluationLoop.run`, as sketched below.

cc @carmocca @justusschock
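For users who cannot patch Lightning itself, the same workaround can be applied as a monkey patch from user code. This is only a sketch under the assumptions stated in this issue (Lightning 2.2, the private `_EvaluationLoop` class and its `batch_progress.reset_on_run()` method); private internals like these may change between releases.

```python
# Hedged sketch of the workaround described above: wrap the private
# _EvaluationLoop.run so batch_progress is reset after every evaluation run.
# Relies on Lightning 2.2 internals and may break on other versions.
from lightning.pytorch.loops import _EvaluationLoop

_original_run = _EvaluationLoop.run


def _patched_run(self, *args, **kwargs):
    result = _original_run(self, *args, **kwargs)
    # Reset progress so a restored mid-epoch checkpoint cannot make the
    # next validation loop believe it already finished.
    self.batch_progress.reset_on_run()
    return result


_EvaluationLoop.run = _patched_run
```

Importing a module containing this patch before constructing the `Trainer` is enough to apply it to every validation run.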