Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

on_epoch_end callback is called before on_validation_epoch_end #6480

Closed · dumitrescustefan closed this issue 3 years ago

dumitrescustefan commented 3 years ago

🐛 Bug

The on_epoch_end hook is called before the epoch actually ends.

What I'm doing:

Up until I upgraded from 1.0.8 (directly to 1.2.2, and then to 1.2.3 today), everything was working fine: validation_epoch_end was logging metrics, and I could read them fine in the callback. Now I'm getting:

Epoch 0:  79%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                 | 42/53 [00:08<00:02,  4.80it/s, loss=2.85, v_num=2]
Traceback (most recent call last):
  File "cube3/trainer.py", line 279, in <module>
    trainer_object.fit()
  File "cube3/trainer.py", line 233, in fit
    trainer.fit(model, train_loader, val_loader)
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 514, in fit
    self.dispatch()
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 554, in dispatch
    self.accelerator.start_training(self)
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 74, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 111, in start_training
    self._results = trainer.run_train()
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in run_train
    self.train_loop.run_training_epoch()
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 558, in run_training_epoch
    self.run_on_epoch_end_hook(epoch_output)
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 806, in run_on_epoch_end_hook
    self.trainer.call_hook('on_epoch_end')
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1102, in call_hook
    trainer_hook(*args, **kwargs)
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/callback_hook.py", line 115, in on_epoch_end
    callback.on_epoch_end(self, self.lightning_module)
  File "./cube3/networks/lemmatizer.py", line 278, in on_epoch_end
    acc = metrics["meta_acc"]
KeyError: 'meta_acc'

This is because the metrics dict in my code, which is really just trainer.callback_metrics, is empty. Furthermore, this fails in epoch 0, right after the sanity check, which finishes fine (I printed the metrics and it prints the 0 accuracy I expected from the sanity check).
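
For reference, this is roughly the failing pattern (a minimal sketch reconstructed from the traceback above; the callback name is made up):

import pytorch_lightning as pl

class MetaAccCallback(pl.Callback):  # hypothetical name
    def on_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics  # empty dict at this point
        acc = metrics["meta_acc"]           # raises KeyError: 'meta_acc'
        print(f"meta accuracy: {acc}")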

What I tried is to switch on_epoch_end to on_validation_epoch_end, and it works. This led me to the conclusion that, since on_epoch_end sees an empty dict while on_validation_epoch_end sees the dict filled in by validation_epoch_end in the pl module, on_epoch_end is called in an incorrect order.

Again, this worked well with 1.0.8 (I don't know in which later version this behaviour changed).

Expected behavior

on_epoch_end should have access to the metrics logged in validation_epoch_end.

Environment

rohitgr7 commented 3 years ago

The order of the calls is:

on_validation_epoch_end
on_epoch_end
on_validation_end

on_validation_end is the last one, so I'd suggest not using it to log anything, since callbacks like ModelCheckpoint/EarlyStopping use this hook and expect everything they monitor to already be present there.

Also, on_epoch_end is called after every train/eval/test epoch, so I'd suggest using on_validation_epoch_end for your use case.

Also, I don't think logging is supported in the on_validation_end hook.
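
If you want to verify the order on your own version, a quick probe callback works (a rough sketch, assuming PL ~1.2):

import pytorch_lightning as pl

class HookOrderProbe(pl.Callback):
    # prints each hook's name as it fires, so you can see the order
    def on_validation_epoch_end(self, trainer, pl_module):
        print("on_validation_epoch_end")

    def on_epoch_end(self, trainer, pl_module):
        print("on_epoch_end")

    def on_validation_end(self, trainer, pl_module):
        print("on_validation_end")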

edenlightning commented 3 years ago

Feel free to reopen if you have any issues!

dumitrescustefan commented 3 years ago

Thanks @rohitgr7 for the clarification.

Maybe you could help me out with a pointer on the following:

Since I shouldn't do any logging in validation_epoch_end (which is in the pl module itself), and considering I need early stopping based on some processing that can't be done in validation_step (it needs to aggregate all the results from the steps and then do some other operations before writing a meta variable that the early-stopping callback reads), what would be the best way to achieve this?

Should I return all the pieces in validation_step, process them in the callback's on_validation_epoch_end, and then write my custom checkpoint saving in the pl module's on_validation_epoch_end? It seems pretty convoluted to me :(

I mean, I could redo the whole thing without any pl callbacks/hooks by dumping all the processing in validation_epoch_end, computing my meta accuracy variable there, and stopping the training there + saving the best model, but the reason to use lightning was to save me from most of this boring stuff and clean up the code :)
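
Concretely, what I'd like to end up with is roughly this (a rough sketch; the module name, the _predict helper and the meta_acc key are made up): aggregate everything in validation_epoch_end, log the result, and let EarlyStopping/ModelCheckpoint monitor it:

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

class LemmatizerModule(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        # return whatever pieces the epoch-level computation needs
        preds, targets = self._predict(batch)  # hypothetical helper
        return {"preds": preds, "targets": targets}

    def validation_epoch_end(self, outputs):
        # aggregate all step outputs, then compute the meta metric
        preds = torch.cat([o["preds"] for o in outputs])
        targets = torch.cat([o["targets"] for o in outputs])
        meta_acc = (preds == targets).float().mean()
        self.log("meta_acc", meta_acc)  # visible to monitoring callbacks

trainer = pl.Trainer(callbacks=[
    EarlyStopping(monitor="meta_acc", mode="max"),
    ModelCheckpoint(monitor="meta_acc", mode="max"),
])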

Furthermore, in https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#hooks I see:

def val_loop():
    model.eval()
    torch.set_grad_enabled(False)

    on_validation_epoch_start()
    val_outs = []
    for val_batch in val_dataloader():
        on_validation_batch_start()

        # -------- val step methods -------
        out = validation_step(val_batch)
        val_outs.append(out)

        on_validation_batch_end(out)

    validation_epoch_end(val_outs)
    on_validation_epoch_end()

    # set up for train
    model.train()
    torch.set_grad_enabled(True)

which tells me that the on_validation_epoch_end callback is called after validation_epoch_end :confused: And again, more confusion comes from the fact that it was working OK up until 1.0.8 (or maybe later, I did not try), and that it passes the validation sanity check successfully.

rohitgr7 commented 3 years ago

hi @dumitrescustefan, my bad. I misread validation_epoch_end as on_validation_end. What you are doing should work, but it looks like it isn't working. Mind sharing a reproducible example with https://colab.research.google.com/drive/1HvWVVTK8j2Nj52qU4Q4YCyzOm0_aLQF3?usp=sharing? Reopening the issue.

rohitgr7 commented 3 years ago

Ok, I see what's happening here. The problem is with the on_epoch_end hook. This hook is called after every loop, irrespective of train/eval/test. When it's called at the end of the training loop, your callback expects val/meta_acc, but the val loop runs after that point, and since val/meta_acc has not been logged yet, it raises an error there. So all you need to do is move the val-related stuff into the on_validation_epoch_end hook.
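
Roughly this (a sketch; the metric key is taken from your traceback):

import pytorch_lightning as pl

class MetaAccCallback(pl.Callback):
    # was on_epoch_end, which also fires after the *train* epoch,
    # before the val loop has logged anything
    def on_validation_epoch_end(self, trainer, pl_module):
        acc = trainer.callback_metrics["meta_acc"]
        print(f"meta accuracy: {acc}")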

dumitrescustefan commented 3 years ago

Thanks @rohitgr7, I switched to on_validation_epoch_end.

IMHO it should be noted in the docs that on_epoch_end here refers to the end of any of the train/val/test "epochs"; for me, an epoch means a train part, then an optional validation and finally another optional test part, all put together -> thus, I expected on_epoch_end to run after train & val & test, not after each part. Anyway, got that cleared up; thanks a lot for your time, closing now.

brynhayder commented 2 months ago

Has the order in which these methods are called changed? I can't see any indication in the docs of what is called when; I've had to look through the code.