Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Callback hook for on_after_optimizer_step #11688

Open tridao opened 2 years ago

tridao commented 2 years ago

🚀 Feature

A callback hook for on_after_optimizer_step.

Motivation

There's a callback hook for on_before_optimizer_step, but not for on_after_optimizer_step. That would be useful for implementing an ExponentialMovingAverage callback: I'd like to update the average weights after the optimizer has updated the parameters. Doing this average-weight update in the on_train_batch_end hook will not be accurate, as the model weights may not get updated after every training batch (due to gradient accumulation).

cc @borda @tchaton @rohitgr7 @carmocca @awaelchli @ninginthecloud @daniellepintz

rohitgr7 commented 2 years ago

you can still do:

def optimizer_step(self, *args, **kwargs):  # override in your LightningModule
    super().optimizer_step(*args, **kwargs)
    # do something on_after_optimizer_step

tridao commented 2 years ago

I think the same argument can be made about other callback hooks. For example, technically one does not need on_before_backward or on_after_backward, since one can override the manual_backward method of the LightningModule. Similarly, one does not need on_before_optimizer_step if one overrides optimizer_step.

However, having a callback hook makes it easy to implement a callback (e.g. ExponentialMovingAverage) that can be applied to many different LightningModules, instead of having to change each LightningModule.

Given that there are hooks for on_before_backward, on_after_backward, and on_before_optimizer_step, I think it makes sense to also expose a hook for on_after_optimizer_step.

carmocca commented 2 years ago

For example, technically one does not need on_before_backward or on_after_backward and one can override the manual_backward method of the LightningModule.

This is not correct because the on_{before,after}_backward are available for both the LightningModule and Callbacks, whereas {,manual_}backward are only available for the LightningModule. Additionally, the backward hooks provide a default implementation so forgetting to call super would become a source of silent bugs.

on_before_optimizer_step was added because we needed an alternative to on_after_backward where gradients were unscaled. It is also different because it runs per optimizer, taking gradient accumulation into account.

If you want this hook to run per optimizer, the above suggestion of overriding optimizer_step is fine, otherwise on_train_batch_end should work too.

We try to add hooks only when there's a need to inject logic and no other hook can cover the use-case. If these suggestions cannot solve your problem, we are open to adding a new hook.

rohitgr7 commented 2 years ago

also, just to add more here: one more reason why we have on_before_optimizer_step is that, in the general case, the actual training_step runs within the closure that is passed to the optimizer, so if you do something like:

def optimizer_step(self, *args, **kwargs):
    # do something before optimizer_step
    super().optimizer_step(*args, **kwargs)

it won't work correctly, since training_step hasn't been executed until optimizer_step is called with the closure.
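
To make the ordering concrete, here is the same override annotated with where code actually runs (a simplified sketch, not the exact Lightning internals):

def optimizer_step(self, *args, **kwargs):
    # code placed here runs BEFORE training_step and backward: in automatic
    # optimization both run inside the closure, which the optimizer itself
    # calls from within its step()
    super().optimizer_step(*args, **kwargs)
    # code placed here runs AFTER training_step, backward, and the parameter
    # update, so logic below the super() call behaves like an
    # on_after_optimizer_step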

tridao commented 2 years ago

Thanks for the explanation! I'd love to get your advice on how to implement the ExponentialMovingAverage callback.

I'm porting training scripts from torchvision and timm to pytorch-lightning, and I'm trying to implement ExponentialMovingAverage (EMA) as a callback. EMA averages the model weights across training steps and is applied after the optimizer update step. EMA is used in both torchvision's training script and timm's training script, which I imagine cover a lot of vision training scenarios.

I can implement EMA as part of the LightningModule (by overriding optimizer_step as you suggested). However, EMA is a general technique that applies to many models, so I'm hoping to implement it as a callback. Do you have suggestions on how to do that?

carmocca commented 2 years ago

I'd say on_train_batch_end is the natural choice.

You can find pseudocode with all the hooks here: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#hooks

tridao commented 2 years ago

Thanks! One problem I can see with on_train_batch_end is gradient accumulation: we only want to update the EMA when the model weights are updated, say every 4 training steps if gradient accumulation is 4. However, on_train_batch_end would be called on every training iteration, not every 4 iterations.

carmocca commented 2 years ago

You could check if batch_idx % trainer.accumulate_grad_batches == 0
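
For illustration, a minimal sketch of an EMA callback along these lines (it assumes automatic optimization, a single optimizer, and 0-based batch indices, hence the (batch_idx + 1); the exact hook signature depends on the Lightning version):

import torch
from pytorch_lightning.callbacks import Callback


class EMACallback(Callback):
    """Minimal sketch: keeps an exponential moving average of the parameters."""

    def __init__(self, decay=0.999):
        self.decay = decay
        self.ema_params = None

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # only update when the optimizer has actually stepped,
        # i.e. once per accumulation window
        if (batch_idx + 1) % trainer.accumulate_grad_batches != 0:
            return
        if self.ema_params is None:
            # lazily start the average from the current weights
            self.ema_params = [p.detach().clone() for p in pl_module.parameters()]
            return
        with torch.no_grad():
            for ema_p, p in zip(self.ema_params, pl_module.parameters()):
                ema_p.mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)

Swapping the averaged weights in for validation or checkpointing is left out of the sketch.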

tridao commented 2 years ago

Yeah, I think that's fine as a workaround. I still think there's a case for a hook that is executed "after the optimizer updates the parameters", which makes the code semantically cleaner. I hope that can be considered in the future. Thanks so much for your help.

carmocca commented 2 years ago

I'll cc contributors in case they think there's enough reason

cc @PyTorchLightning/core-lightning

rohitgr7 commented 2 years ago

well,

if we consider implementing this inside a callback, then it's not so trivial with any other hook, considering some edge cases: for example, a loss is not returned from training_step, or an optimization step is taken for the last batch when accumulation doesn't align with the total number of batches (e.g. with total batches = 7 and accumulation_factor = 3, we take 3 optimization steps).

this looks like a genuine use case and we have seen issues regarding EMA before, so I think we can add it. The most recent one is here: https://github.com/PyTorchLightning/pytorch-lightning/issues/10914, but looking at the callback provided there, it might not work with accumulating gradients.
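
To make the edge case concrete: with total batches = 7 and accumulation_factor = 3, the optimizer steps after batch indices 2, 5, and 6, so a plain modulo check misses the final step. A more defensive check could look like this (sketch only; it assumes trainer.num_training_batches is known, which may not hold for iterable datasets):

def optimizer_stepped(trainer, batch_idx):
    # batch_idx is 0-based; the final, possibly incomplete, accumulation
    # window also triggers an optimizer step
    at_boundary = (batch_idx + 1) % trainer.accumulate_grad_batches == 0
    is_last_batch = (batch_idx + 1) == trainer.num_training_batches
    return at_boundary or is_last_batch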

flukeskywalker commented 2 years ago

@rohitgr7 The callback in #10914 does not appear to handle gradient accumulation. Are you sure it does?

rohitgr7 commented 2 years ago

now work with accumulating gradients.

I meant not.. typo :p

stale[bot] commented 2 years ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

fschlatt commented 11 months ago

To add another use case: I've been trying to implement a GradientAccumulationScheduler that updates the number of accumulation batches based on the step rather than the epoch. This is handy, for example, for large or iterable datasets.

If the gradient accumulation size is not updated correctly, different batches within a single optimizer step may end up weighted differently, since the losses are normalized per training step. https://github.com/Lightning-AI/lightning/blob/008a83ed5a565bfe9a4d3df17f39c1493b5f62a2/src/lightning/pytorch/loops/optimization/automatic.py#L322

To ensure this doesn't happen, the gradient accumulation size should be updated after the optimizer step.
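
Without the proposed hook, a rough sketch of what a step-based scheduler could look like today, changing the factor only at an accumulation boundary (the class name and scheduling format are made up for illustration; it assumes trainer.accumulate_grad_batches can be reassigned from a callback, as the built-in epoch-based scheduler does):

from pytorch_lightning.callbacks import Callback


class StepGradientAccumulationScheduler(Callback):
    """Hypothetical step-based variant of GradientAccumulationScheduler."""

    def __init__(self, scheduling):
        # e.g. {0: 1, 10_000: 4, 50_000: 8}, mapping global step -> factor
        self.scheduling = dict(sorted(scheduling.items()))

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # change the factor only right after an optimizer step, so every batch
        # inside one accumulation window keeps the same loss normalization
        if (batch_idx + 1) % trainer.accumulate_grad_batches != 0:
            return
        new_factor = trainer.accumulate_grad_batches
        for step, factor in self.scheduling.items():
            if trainer.global_step >= step:
                new_factor = factor
        trainer.accumulate_grad_batches = new_factor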

DanTremonti commented 10 months ago

In a particular use case, if the updated model parameters contain non-finite values, I aim to log the necessary data (present in the current step) to probe and reproduce the issue. Note that logging parameters and gradients before the optimizer step is already in place.

Would on_after_optimizer_step be the appropriate location for implementing this validation check?
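
For reference, a sketch of what that check could look like against the proposed hook (the hook name and signature below are hypothetical, since the hook does not exist yet; today the same body could go in an optimizer_step override):

import torch
from pytorch_lightning.callbacks import Callback


class NonFiniteParameterCheck(Callback):
    """Hypothetical: relies on the proposed on_after_optimizer_step hook."""

    def on_after_optimizer_step(self, trainer, pl_module, optimizer):
        for name, param in pl_module.named_parameters():
            if not torch.isfinite(param).all():
                # log whatever is needed to probe and reproduce the issue
                print(f"non-finite values in {name} at global step {trainer.global_step}")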