Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Incorrect batch progress saved in checkpoint at every_n_train_steps #18060

Open shuaitang5 opened 1 year ago

shuaitang5 commented 1 year ago

Bug description

When a checkpoint is saved with every_n_train_steps=3, the save happens inside the on_train_batch_end hook of the ModelCheckpoint class. During that save, the fit loop's state dict is snapshotted and stored, including its batch progress. However, the batch progress is only incremented after on_train_batch_end returns, i.e. after the checkpoint has already been saved, so the saved checkpoint contains an incorrect batch_progress that looks like this:

# in checkpoint file checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']
{total: {ready: 3, completed: 2, started: 3, processed: 3}}

The expected value is {total: {ready: 3, completed: 3, started: 3, processed: 3}}, which is what a checkpoint saved after validation contains.

As a result, when we resume from the batch-end checkpoint, the starting batch_idx is 2 while global_step is 3 inside the model's training_step (they should match), and every checkpoint saved afterwards carries an incorrect step value in its file name. This doesn't seem like expected behavior; am I missing something?
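
To make the mismatch concrete, here is a minimal sketch (the checkpoint file name is just an example) that loads a step checkpoint and prints the saved progress counters:

    import torch

    # hypothetical checkpoint written by ModelCheckpoint(every_n_train_steps=3)
    ckpt = torch.load("epoch=0-step=3.ckpt", map_location="cpu")

    progress = ckpt["loops"]["fit_loop"]["epoch_loop.batch_progress"]
    print(progress["total"])
    # -> {'ready': 3, 'completed': 2, 'started': 3, 'processed': 3}
    # on resume, the restored `completed` counter is behind, so batch_idx lags global_step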

I'm currently working around the issue with a hack in an on_train_batch_end override like this:

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        # hack: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/loops/training_epoch_loop.py#L233-L237
        # At the time this hook is called, the `completed` value in batch progress has not been incremented yet.
        # If a checkpoint is saved here, it will contain an incorrect `completed` value in its batch progress.
        # Resuming from such a checkpoint leaves batch_idx one step behind global_step in the LightningModule's training_step.
        trainer.fit_loop.epoch_loop.batch_progress.increment_completed()
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

        # revert the change to the `completed` value in batch progress
        trainer.fit_loop.epoch_loop.batch_progress.total.completed -= 1
        trainer.fit_loop.epoch_loop.batch_progress.current.completed -= 1
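
For completeness, a sketch of how the workaround can be wired into a run; the subclass name is just illustrative and the imports assume the unified `lightning` package:

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    class PatchedModelCheckpoint(ModelCheckpoint):
        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
            # same hack as above: bump `completed` for the snapshot, then restore it
            trainer.fit_loop.epoch_loop.batch_progress.increment_completed()
            super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
            trainer.fit_loop.epoch_loop.batch_progress.total.completed -= 1
            trainer.fit_loop.epoch_loop.batch_progress.current.completed -= 1

    trainer = Trainer(
        max_steps=12,
        callbacks=[PatchedModelCheckpoint(dirpath="checkpoints", every_n_train_steps=3)],
    )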

What version are you seeing the problem on?

v1.9, master

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment

```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
```

More info

No response

cc @carmocca @justusschock

stale[bot] commented 1 year ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

ordabayevy commented 10 months ago

I came across this issue as well. Is there a solution to it?

heth27 commented 6 months ago

I came across this issue as well. Is there a solution to it?

You can change the checkpoint when it is being saved, for example by taking the `processed` value, so that the batch is re-run if the optimizer step did not complete correctly:


        checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] = \
            checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['processed']
        checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] = \
            checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['processed']
        checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] = \
            checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['processed']

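For reference, a minimal sketch of one place this patch could live; the callback name and hook placement are assumptions, not something the thread prescribes:

    from lightning.pytorch.callbacks import Callback

    class FixBatchProgress(Callback):  # hypothetical callback name
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            # rewrite the snapshotted counters so the last processed batch is also
            # counted as completed when the run is resumed
            progress = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']
            progress['total']['completed'] = progress['total']['processed']
            progress['current']['completed'] = progress['current']['processed']
            checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] = \
                progress['current']['processed']
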
docbeaker commented 3 months ago

Definitely still an issue and definitely still open.

iamlockelightning commented 1 week ago

Same problem. Hope to see an appropriate solution.