tridao opened 2 years ago
you can still do:
```python
def optimizer_step(self, *args, **kwargs):
    super().optimizer_step(*args, **kwargs)
    # do something on_after_optimizer_step
```
I think the same argument can be made about other callback hooks. For example, technically one does not need `on_before_backward` or `on_after_backward`, since one can override the `manual_backward` method of the LightningModule. Similarly, one does not need `on_before_optimizer_step`, since one can override `optimizer_step`.

However, having a callback hook makes it easy to implement a callback (e.g. ExponentialMovingAverage) that can be applied to many different LightningModules, instead of having to change each LightningModule.

Given that there are hooks for `on_before_backward`, `on_after_backward`, and `on_before_optimizer_step`, I think it makes sense to expose a hook for `on_after_optimizer_step`.
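For illustration, here is a minimal sketch of the kind of EMA callback such a hook would enable. The `on_after_optimizer_step` name and signature are hypothetical, since the hook doesn't exist yet:

```python
import torch
from pytorch_lightning import Callback


class EMAWeights(Callback):
    """Sketch of an EMA callback that works with any LightningModule."""

    def __init__(self, decay: float = 0.999):
        self.decay = decay
        self.ema_params = None

    def on_after_optimizer_step(self, trainer, pl_module, optimizer):
        # Hypothetical hook: would run once per optimizer update,
        # after the parameters have been modified.
        params = list(pl_module.parameters())
        if self.ema_params is None:
            self.ema_params = [p.detach().clone() for p in params]
            return
        with torch.no_grad():
            for ema_p, p in zip(self.ema_params, params):
                # ema = decay * ema + (1 - decay) * param
                ema_p.mul_(self.decay).add_(p, alpha=1.0 - self.decay)
```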
> For example, technically one does not need `on_before_backward` or `on_after_backward` and one can override the `manual_backward` method of the LightningModule.
This is not correct, because `on_{before,after}_backward` are available for both the LightningModule and Callbacks, whereas `{,manual_}backward` are only available on the LightningModule. Additionally, the backward hooks provide a default implementation, so forgetting to call `super()` would become a source of silent bugs.
`on_before_optimizer_step` was added because we needed an alternative to `on_after_backward` where gradients were unscaled. It is also different in that it runs per optimizer, taking gradient accumulation into account. If you want this hook to run per optimizer, the above suggestion of overriding `optimizer_step` is fine; otherwise, `on_train_batch_end` should work too.

We try to add hooks only when there's a need to inject logic and no other hook can cover the use case. If these suggestions cannot solve your problem, we are open to adding a new hook.
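For example, a callback can use `on_before_optimizer_step` to inspect the exact (unscaled, accumulated) gradients the optimizer is about to apply. A minimal sketch, assuming the recent hook signature (older Lightning versions also pass an `optimizer_idx`):

```python
from pytorch_lightning import Callback


class GradNormLogger(Callback):
    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        # Runs once per optimizer update, after AMP unscaling and after
        # gradient accumulation, so these are the gradients actually applied.
        sq_sum = 0.0
        for p in pl_module.parameters():
            if p.grad is not None:
                sq_sum += p.grad.detach().norm(2).item() ** 2
        pl_module.log("grad_norm", sq_sum ** 0.5)
```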
Also, just to add more here: one more reason why we have `on_before_optimizer_step` is that, in the general case, the actual `training_step` runs within the closure that is passed to the optimizer, so if you do something like:

```python
def optimizer_step(self, *args, **kwargs):
    # do something before optimizer_step
    super().optimizer_step(*args, **kwargs)
```

it won't work correctly, since `training_step` hasn't been executed until `optimizer_step` is called with the closure.
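Roughly, the default `optimizer_step` looks something like this simplified sketch (not the actual implementation; the exact signature has varied across Lightning versions), which is why anything placed before the `super()` call runs before `training_step`:

```python
from pytorch_lightning import LightningModule


class MyModule(LightningModule):
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure=None, **kwargs):
        # optimizer_closure wraps training_step + backward; it only runs
        # when the optimizer invokes it inside .step(), so for the current
        # batch no loss or gradients exist before this line.
        optimizer.step(closure=optimizer_closure)
```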
Thanks for the explanation! I'd love to get your advice on how to implement the ExponentialMovingAverage callback.

I'm porting training scripts from torchvision and timm to pytorch-lightning, and I'm trying to implement ExponentialMovingAverage (EMA) as a callback. EMA averages the model weights across training steps and is applied after the optimizer update step. It is used in both the torchvision training script and timm's training script, which I imagine cover a lot of vision training scenarios.

I can implement EMA as part of the LightningModule (by overriding `optimizer_step` as you suggested; see the sketch below). However, EMA is a general technique that applies to many models, so I'm hoping to implement it as a callback. Do you have suggestions on how to do that?
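A minimal sketch of that LightningModule version, using `torch.optim.swa_utils.AveragedModel` with an EMA averaging function (the network, decay, and optimizer here are placeholders):

```python
import torch
from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(32, 10)  # placeholder network
        self.ema_model = torch.optim.swa_utils.AveragedModel(
            self.model, avg_fn=lambda avg, new, num: 0.999 * avg + 0.001 * new
        )

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.model(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=0.1)

    def optimizer_step(self, *args, **kwargs):
        super().optimizer_step(*args, **kwargs)  # parameters are updated here
        self.ema_model.update_parameters(self.model)  # then refresh the EMA copy
```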
I'd say `on_train_batch_end` is the natural choice. You can find pseudocode with all the hooks here: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#hooks
Thanks!
One problem I can see with `on_train_batch_end` is gradient accumulation: we only want to update the EMA when the model weights are updated, say every 4 training steps if gradient accumulation is 4. However, it seems that `on_train_batch_end` would be called every training iteration, not every 4 iterations.
You could check if `batch_idx % trainer.accumulate_grad_batches == 0`.
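Something like the following sketch. One caveat: `batch_idx` is zero-based, so with accumulation 4 the optimizer steps after batches 3, 7, 11, ..., which suggests checking `batch_idx + 1` instead. The `on_train_batch_end` signature shown is the recent one (older versions also pass a `dataloader_idx`), and `update_ema` is a hypothetical helper:

```python
from pytorch_lightning import Callback


class EMACallback(Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Only update the EMA at an optimizer-step boundary; batch_idx is
        # zero-based, hence the +1.
        if (batch_idx + 1) % trainer.accumulate_grad_batches == 0:
            self.update_ema(pl_module)  # hypothetical EMA update helper
```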
Yeah, I think that's fine as a workaround. I still think there's a case for a hook that is executed "after the optimizer updates the parameters", which makes the code semantically cleaner. I hope that can be considered in the future. Thanks so much for your help.
I'll cc contributors in case they think there's enough reason.
cc @PyTorchLightning/core-lightning
Well, if we consider implementing this inside a callback, then it's not so trivial using any other hook, considering some edge cases: a loss may not be returned from `training_step`, or an optimization step may run for the last batch when the accumulation factor doesn't divide the total number of batches (e.g. with total batches = 7 and accumulation_factor = 3, we take 3 optimization steps).

This looks like a genuine use case, and we have seen issues regarding EMA before, so I think we can add it. The most recent one is here: https://github.com/PyTorchLightning/pytorch-lightning/issues/10914, but the callback provided there might not work with accumulating gradients.
@rohitgr7 Just to confirm: the callback in #10914 does not handle gradient accumulation, right?

Right, it does not :p
To add another use case: I've been trying to implement a GradientAccumulationScheduler that updates the number of accumulation batches based on the step rather than the epoch. This is handy, for example, for large or iterative datasets.

If the gradient accumulation size is not updated at the right time, different batches within one optimizer step can end up weighted differently, since the losses are normalized per training step: https://github.com/Lightning-AI/lightning/blob/008a83ed5a565bfe9a4d3df17f39c1493b5f62a2/src/lightning/pytorch/loops/optimization/automatic.py#L322

To ensure this doesn't happen, the gradient accumulation size should be updated right after the optimizer step.
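A sketch of what I'm after, assuming `trainer.accumulate_grad_batches` can be reassigned from a callback the way the built-in `GradientAccumulationScheduler` does per epoch; the `scheduling` mapping and hook choice are my own assumptions:

```python
from pytorch_lightning import Callback


class StepAccumulationScheduler(Callback):
    def __init__(self, scheduling: dict):
        # e.g. {0: 2, 10_000: 8}: accumulate 2 batches from step 0,
        # then 8 from global step 10k onwards
        self.scheduling = scheduling

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Only switch at an optimizer-step boundary, so every batch within
        # one optimizer step is normalized by the same factor.
        if (batch_idx + 1) % trainer.accumulate_grad_batches == 0:
            reached = [s for s in self.scheduling if s <= trainer.global_step]
            if reached:
                trainer.accumulate_grad_batches = self.scheduling[max(reached)]
```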
In a particular use case, if the updated model parameters contain non-finite values, I want to log the necessary data (present in the current step) to probe and reproduce the issue. Note that logging parameters and gradients before the optimizer step is already in place. Would `on_after_optimizer_step` be the appropriate place to implement this validation check?
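As a sketch of what I'd write if the hook existed (the hook name and signature are the proposed ones, not an existing API):

```python
import torch
from pytorch_lightning import Callback


class NonFiniteParamCheck(Callback):
    def on_after_optimizer_step(self, trainer, pl_module, optimizer):
        # Proposed hook: the parameters have just been updated at this point.
        for name, param in pl_module.named_parameters():
            if not torch.isfinite(param).all():
                # dump whatever is needed to reproduce the issue, e.g. the
                # current batch, the gradients logged before the step, RNG state
                raise RuntimeError(
                    f"Non-finite values in {name} at global step {trainer.global_step}"
                )
```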
🚀 Feature

A callback hook for `on_after_optimizer_step`.

Motivation

There's a callback hook for `on_before_optimizer_step`, but not for `on_after_optimizer_step`. That would be useful for implementing an ExponentialMovingAverage callback: I'd like to update the average weights after the optimizer has updated the parameters. Doing this average weight update with the `on_train_batch_end` hook will not be accurate, as the model weights may not get updated after every training batch (due to gradient accumulation).

cc @borda @tchaton @rohitgr7 @carmocca @awaelchli @ninginthecloud @daniellepintz