Closed dumitrescustefan closed 3 years ago
the order for the calls is:

1. `on_validation_epoch_end`
2. `on_epoch_end`
3. `on_validation_end`

`on_validation_end` is the last one, so I'd suggest not using it to log anything, since some callbacks like `ModelCheckpoint`/`EarlyStopping` use this hook and expect everything to be present there if you want to monitor something. Also, `on_epoch_end` is called after every train/eval/test epoch end, so I'd suggest using `on_validation_epoch_end` in your use case. Also, I don't think logging is supported in the `on_validation_end` hook.
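The ordering above can be illustrated with a minimal plain-Python stand-in that just records hooks in the order described (an illustration only, not the real Trainer implementation):

```python
# Stand-in object that records which hooks were fired, in order.
class HookRecorder:
    def __init__(self):
        self.calls = []

    def on_validation_epoch_end(self):
        self.calls.append("on_validation_epoch_end")

    def on_epoch_end(self):
        self.calls.append("on_epoch_end")

    def on_validation_end(self):
        self.calls.append("on_validation_end")


def run_val_teardown(hooks):
    # Fire the hooks in the order described in the comment above.
    hooks.on_validation_epoch_end()
    hooks.on_epoch_end()
    hooks.on_validation_end()
    return hooks.calls


print(run_val_teardown(HookRecorder()))
# -> ['on_validation_epoch_end', 'on_epoch_end', 'on_validation_end']
```

Since `on_validation_end` fires last, anything `ModelCheckpoint`/`EarlyStopping` needs must already be logged before it runs.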
Feel free to reopen if you have any issues!
Thanks @rohitgr7 for the clarification.
Maybe you could help me out with a pointer on the following:
Since I shouldn't do any logging in `validation_epoch_end`, which is in the pl module itself, and considering I need early stopping based on some processing that can't be done in `validation_step` (it needs to aggregate all the results from the steps and then do some other operations before writing a meta variable that the early stopping callback reads), what would be the best way to achieve this? Should I return all the pieces in `validation_step`, process them in the callback in `on_validation_epoch_end`, and then write my custom checkpoint saving in the pl module's `on_validation_epoch_end`? It seems pretty convoluted to me :(
I mean, I could redo the whole thing without any pl callbacks/hooks by dumping all the processing into `validation_epoch_end`, computing my meta accuracy variable there, and stopping the training + saving the best model there, but the reason to use Lightning was to save me from most of this boring stuff and clean up the code :)
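The aggregate-then-compute pattern being asked about can be sketched with plain functions standing in for the LightningModule methods (the real "meta accuracy" post-processing is not shown in the thread, so plain accuracy is used here as a hypothetical stand-in):

```python
def validation_step(batch):
    # Return the raw per-batch pieces instead of a final metric; the
    # real computation needs all batches together.
    preds, targets = batch
    return {"preds": preds, "targets": targets}


def validation_epoch_end(outputs):
    # Aggregate every step's output, then run the cross-batch
    # processing and compute the final "meta" metric.
    preds = [p for out in outputs for p in out["preds"]]
    targets = [t for out in outputs for t in out["targets"]]
    meta_acc = sum(p == t for p, t in zip(preds, targets)) / len(targets)
    # In a LightningModule you would now call
    #   self.log("val/meta_acc", meta_acc)
    # so EarlyStopping/ModelCheckpoint can monitor "val/meta_acc".
    return meta_acc


outs = [validation_step(([1, 0], [1, 1])), validation_step(([1], [1]))]
print(validation_epoch_end(outs))  # 2 of 3 correct
```

Once the metric is logged this way, `EarlyStopping(monitor="val/meta_acc")` and `ModelCheckpoint(monitor="val/meta_acc")` can consume it without any custom callback code.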
Furthermore, in https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#hooks I see:
```python
def val_loop():
    model.eval()
    torch.set_grad_enabled(False)

    on_validation_epoch_start()

    val_outs = []
    for val_batch in val_dataloader():
        on_validation_batch_start()

        # -------- val step methods -------
        out = validation_step(val_batch)
        val_outs.append(out)

        on_validation_batch_end(out)

    validation_epoch_end(val_outs)
    on_validation_epoch_end()

# set up for train
model.train()
torch.set_grad_enabled(True)
```
which tells me that the `on_validation_epoch_end` callback is called after `validation_epoch_end` :confused:
And again, more confusion is generated by the fact that it was working OK up until 1.0.8 (or maybe later, I did not try), and that it also passes the validation sanity check successfully.
hi @dumitrescustefan my bad. I misread `validation_epoch_end` as `on_validation_end`. What you are doing should work, but it looks like it isn't working. Mind sharing a reproducible example with https://colab.research.google.com/drive/1HvWVVTK8j2Nj52qU4Q4YCyzOm0_aLQF3?usp=sharing?
reopening the issue.
ok, I see what's happening here.
The problem is with the `on_epoch_end` hook. This hook is called after every loop, irrespective of train/eval/test. When it's called at the end of the training loop, it expects `val/meta_acc`, but the val loop runs after this, and since `val/meta_acc` hasn't been logged yet, it raises an error there. So all you need to do is move the val-related stuff inside the `on_validation_epoch_end` hook.
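A minimal plain-Python simulation of this timing (an illustration of the explanation above, with a dict standing in for `trainer.callback_metrics` and a made-up metric value; not the real Trainer code):

```python
# Stand-in for trainer.callback_metrics.
callback_metrics = {}


def train_loop():
    pass  # training runs; nothing has logged "val/meta_acc" yet


def val_loop():
    # validation_epoch_end logs the metric here (value is made up).
    callback_metrics["val/meta_acc"] = 0.83


def on_epoch_end():
    # Fires right after train_loop, i.e. BEFORE the val loop has run.
    return callback_metrics["val/meta_acc"]


def on_validation_epoch_end():
    # Fires inside the val loop, after validation_epoch_end has logged.
    return callback_metrics["val/meta_acc"]


train_loop()
try:
    on_epoch_end()          # raises KeyError: metric not logged yet
except KeyError:
    print("on_epoch_end: val/meta_acc missing")
val_loop()
print(on_validation_epoch_end())  # -> 0.83
```

This is exactly why the callback sees an empty dict in `on_epoch_end` after the training loop, while `on_validation_epoch_end` sees the logged value.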
Thanks @rohitgr7, I switched to `on_validation_epoch_end`.
IMHO it should be noted in the docs that `on_epoch_end` refers here to the end of any of the train/val/test "epochs"; for me, an epoch means a train part, then an optional validation part, and finally an optional test part, put together -> thus, `on_epoch_end` should mean "run after train & val & test", not after each part. Anyway, got that cleared up, thanks a lot for your time, closing now.
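The switch described above amounts to something like the following hypothetical reconstruction (the original `PrintAndSaveCallback` body is not shown in the thread; in real code this would subclass `pytorch_lightning.callbacks.Callback`):

```python
class PrintAndSaveCallback:  # sketch; would subclass pl.callbacks.Callback
    def on_validation_epoch_end(self, trainer, pl_module):
        # By this point validation_epoch_end has run and logged the
        # metric, so trainer.callback_metrics already contains it.
        meta_acc = trainer.callback_metrics["val/meta_acc"]
        print(f"val/meta_acc = {meta_acc}")
        return meta_acc
```

The only change from the failing version is the hook name: the body that previously lived in `on_epoch_end` moves unchanged into `on_validation_epoch_end`.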
Has the order in which these methods are called been changed? I can't see any indication of what is called when in the docs. I've had to look through the code.
🐛 Bug
The `on_epoch_end` hook is called before the epoch ends.
What I'm doing:
In the pl model I have `validation_epoch_end`, which computes an accuracy that I log with `self.log("val/meta_acc", meta_acc)`.
I have a callback with a single method defined as:
The trainer gets this callback via `callbacks=[PrintAndSaveCallback]`.
Up until I upgraded from 1.0.8 (directly to 1.2.2, and to 1.2.3 today) everything was working fine. `validation_epoch_end` was logging metrics, and in the callback I read them fine. Now I'm getting:
This is because the `metrics` in my code, which is actually just `trainer.callback_metrics`, is an empty dict. Furthermore, this fails in epoch 0, after the sanity check, which finishes fine (I printed the metrics, and it prints the 0 accuracy I expected from the sanity check).
What I tried is to switch `on_epoch_end` with `on_validation_epoch_end`, and it works. This led me to the conclusion that since `on_epoch_end` returns an empty dict while `on_validation_epoch_end` returns a dict filled in from `validation_epoch_end` in the pl module, `on_epoch_end` is called in an incorrect order.
Again, this worked well with 1.0.8 (I don't know in which later version this behaviour changed).
Expected behavior
`on_epoch_end` should have the metrics from `validation_epoch_end`.
Environment