Open shuaitang5 opened 1 year ago
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!
I came across this issue as well. Is there a solution to it?
> I came across this issue as well. Is there a solution to it?
You can change the checkpoint when it's being saved, for example by taking the `processed` value, so that it re-runs the batch if the optimizer step did not complete correctly:
```python
# Align the 'completed' counters with 'processed' in the saved fit-loop state
checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] = \
    checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['processed']
checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] = \
    checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['processed']
# Keep the stepped-batch counter consistent with the patched progress
checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] = \
    checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['processed']
```
Definitely still an issue and definitely still open.
Same problem. Hope to see an appropriate solution.
Bug description
When saving a checkpoint with `every_n_train_steps=3`, the save happens inside the `on_train_batch_end` hook of the `ModelCheckpoint` class. During that save, the state dict of the fit loop is snapshotted along with its batch progress. However, `batch_progress` is only incremented after `on_train_batch_end` returns (i.e. after the checkpoint is saved), so the saved checkpoint contains an incorrect `batch_progress` in which `completed` lags one step behind the other counters. The expected value should be:
{total: {ready: 3, completed: 3, started: 3, processed: 3}}
which is what the checkpoint saved after validation contains. As a result, when we resume from a batch-end checkpoint, the starting `batch_idx` is 2 while the global step is 3 inside the model's `training_step` (they should match), and every subsequently saved checkpoint carries an incorrect step value in its file name. This doesn't seem like expected behavior; am I missing something?
I'm currently working around this with a hack in an `on_train_batch_end` override.
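The hack itself isn't shown in the report. Purely as an illustration (not the original author's code), the same counter fix from the comment above could be applied from a callback's `on_save_checkpoint` hook, which receives the checkpoint dict before it is written; whether the `'loops'` entry is present at that point depends on the Lightning version, so treat this as a sketch:

```python
import pytorch_lightning as pl


class FixBatchProgress(pl.Callback):
    """Sketch: patch the fit-loop counters in mid-epoch checkpoints before they are written."""

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        loops = checkpoint.get("loops")
        if not loops:  # 'loops' may be absent (e.g. weights-only checkpoints)
            return
        progress = loops["fit_loop"]["epoch_loop.batch_progress"]
        # Same fix as in the comment above: mark the just-finished batch as completed
        progress["total"]["completed"] = progress["total"]["processed"]
        progress["current"]["completed"] = progress["current"]["processed"]
        loops["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"] = \
            progress["current"]["processed"]


# Usage (illustrative): trainer = pl.Trainer(callbacks=[FixBatchProgress(), ...])
```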
What version are you seeing the problem on?
v1.9, master
How to reproduce the bug
No response
Error messages and logs
Environment
Current environment
```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
```
More info
No response
cc @carmocca @justusschock