Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Updating 1.0.3 -> 1.0.4 breaks pytorch module hooks #5131

Closed ludwigwinkler closed 3 years ago

ludwigwinkler commented 3 years ago

1.0.3 -> 1.0.4 breaks pytorch-hooks

What is your question?

I implemented a custom optimizer that stores the input and the incoming gradients for each layer and recomputes the gradients in a per-sample fashion (which I do because I want to use K-FAC for my optimization).

I'm using PyTorch hooks on the modules (i.e. module123.register_forward_pre_hook(_save_input_function), which stores the inputs in the optimizer's state dictionary, and likewise module123.register_backward_hook(_save_gradient_function) for the backward gradients), and the cached values are then used when calling optimizer.step().
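For reference, a minimal sketch of the kind of hook bookkeeping I mean (toy names and a toy model, not my actual K-FAC code):

```python
import torch

# Cache each Linear layer's input and the gradient w.r.t. its output, keyed by
# module, so a K-FAC-style optimizer.step() can read them later.
state = {}

def _save_input_function(module, inputs):
    state.setdefault(module, {})["input"] = inputs[0].detach()

def _save_gradient_function(module, grad_input, grad_output):
    state.setdefault(module, {})["grad_output"] = grad_output[0].detach()

model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2))
for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        module.register_forward_pre_hook(_save_input_function)
        module.register_backward_hook(_save_gradient_function)  # register_full_backward_hook on newer torch

# One forward/backward pass fills `state`.
loss = model(torch.randn(4, 8)).sum()
loss.backward()
print([(type(m).__name__, sorted(d)) for m, d in state.items()])
```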

When I updated from 1.0.1 to 1.1.0, it broke the hooks required by the KFAC optimizer. I checked different versions, and somewhere between 1.0.3 and 1.0.4 something changed such that the Trainer ignores/steps over the hooks in the forward and backward pass.

I checked the release notes but couldn't identify anything essential that would warrant these breaking changes.

Additionally, there were issues popping up with manually modifying gradients in other sampling-based types of optimizers.

These issues do not exist for the standard optimizers straight from PyTorch's optim module.

Code

It would be a huge amount of code, as the implementation of KFAC is a bit sophisticated. But if nothing else helps, I could write a minimal working example from scratch.

williamFalcon commented 3 years ago

@tchaton

williamFalcon commented 3 years ago

@ludwigwinkler i suspect using the boring model with an optimizer that overrides a hook will replicate this problem

ludwigwinkler commented 3 years ago

Hi,

this is a boring model with an Optimizer that uses forward and backward hooks to do stuff with the parameters: https://colab.research.google.com/drive/1z69hmWUybxW58iGMR9EQ3wlT5jwxJxzt?usp=sharing

Let me know if something is unclear/needs more comments or needs polishing.

Thanks for all the effort you put into this project. =)

ludwigwinkler commented 3 years ago

@tchaton Did you have time to look at this?

lizhitwo commented 3 years ago

I had the same bug, and it turns out to be this: when it is garbage-collected, pytorch_lightning.core.optimizer.LightningOptimizer calls its __del__, which in turn calls KFAC's __del__, which deletes all the hooks.

As a user, I find it rather horrifying that pytorch lightning automatically calls my optimizer's __del__ even though I am still using my optimizer somewhere.

ludwigwinkler commented 3 years ago

@lizhitwo thank you for your insight! Where can I find the deletion call in the source code? Or should I override my custom optimizer's __del__ function?

lizhitwo commented 3 years ago

There's no explicit call, but you can try it by constructing a LightningOptimizer that wraps KFAC, then setting it to None while retaining the KFAC object, and then calling the garbage collector (gc.collect()). I think you should then see KFAC's __del__ being called, although you still have a KFAC object.
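Something along these lines (a rough sketch; the KFAC class here is just a stand-in with a __del__, and the exact behaviour depends on the pl version):

```python
import gc
import torch
from pytorch_lightning.core.optimizer import LightningOptimizer

class KFAC(torch.optim.SGD):
    """Stand-in for the real K-FAC optimizer; only __del__ matters for this test."""
    def __del__(self):
        print("KFAC.__del__ called -> the hooks would be removed here")

model = torch.nn.Linear(2, 2)
kfac = KFAC(model.parameters(), lr=0.1)

wrapper = LightningOptimizer(kfac)  # pl wraps user optimizers like this internally
wrapper = None                      # drop only the wrapper, `kfac` stays alive
gc.collect()

# On the affected versions the message above is printed, i.e. KFAC's __del__
# runs even though the `kfac` object is still referenced and in use.
```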

tchaton commented 3 years ago

Dear @ludwigwinkler @lizhitwo,

Great find! We attach most of your optimizer's functions to the LightningOptimizer, so that it exposes the same attributes and properties.

We didn't think about the __del__ function being triggered and accidentally deleting your hooks.

I couldn't reproduce the bug you shared. The notebook contained an error: the provided closure wasn't being called, so the hooks weren't triggered.

I have made a PR here: https://github.com/PyTorchLightning/pytorch-lightning/pull/6305. Please check that it covers your use case.

Sorry for the inconvenience.

Best, T.C

ludwigwinkler commented 3 years ago

Sorry for the long delay.

So I revisited my old code, but the deletion of the hooks still persists in version 1.2.6. :(

You can check the colab boring model here: https://colab.research.google.com/drive/1z69hmWUybxW58iGMR9EQ3wlT5jwxJxzt?usp=sharing

The optimizer's internal state won't be able to reference the results of the hooks, which are stored in its internal dictionary.

lizhitwo commented 3 years ago

Hi ludwigwinkler, from the stack it seems that your code runs the optimizer step before the closure (the closure of def step(self, closure=None)). For some time now, pl requires us users to either pass the closure on to the parent class to handle, or run the closure ourselves, and the closure is what actually calls training_step, etc. Otherwise no forward pass is even evaluated, which means no hooks are invoked.
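Concretely, the two options look roughly like this (illustrative only; these classes are placeholders, not the real K-FAC code):

```python
import torch

class RunClosureMyself(torch.optim.SGD):
    # Option 1: run the closure ourselves, then do the custom update logic.
    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()  # under pl this runs training_step + backward
        super().step()        # stand-in for the K-FAC update
        return loss

class DelegateClosure(torch.optim.SGD):
    # Option 2: hand the closure to the parent class and let it call it.
    def step(self, closure=None):
        return super().step(closure)
```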

To pl devs: Is it also a bug that in this tutorial the closure is only evaluated in some iterations, but just ignored (no forward passes) in others?

ludwigwinkler commented 3 years ago

Hi lizhitwo, I'm only familiar with the concept of closures from optimizers like L-BFGS, which recompute multiple values and/or states to find the right step size in optim.step(closure=closure).

Do you mean that the forward pass should be encapsulated in a closure and subsequently passed to the optimizer, such that the forward pass itself is executed within the optimizer?

Thanks for your effort to help me and everybody else out on this, Ludwig

lizhitwo commented 3 years ago

I'm only familiar with the concept of closures from optimizers like L-BFGS

Yeah, outside pytorch lightning the closure is only useful for LBFGS. In pl, everything uses a closure, whose purpose I think is to support all optimizers at once, including LBFGS.

Do you mean that the forward pass should be encapsulated in a closure

That's already done automatically in pl. Inside your optimizer step, you can try to print the closure argument, which will be a function instead of None. You can add print statements to your training step and your optimizer step to observe this behavior. You should be able to see that the training step is not run if you don't call the closure, but if you do call it (closure()) you should see the training step being called.
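For example, something like this minimal setup makes the behavior visible (all names made up; the exact Trainer arguments may differ a bit between pl versions):

```python
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl

class PrintingOptimizer(torch.optim.SGD):
    """Plain SGD, but prints what it receives so the closure mechanism is visible."""
    def step(self, closure=None):
        print("optimizer.step called, closure =", closure)  # a callable, not None
        loss = None
        if closure is not None:
            loss = closure()  # without this call, training_step is never run (as in the bug above)
        super().step()
        return loss

class PrintingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        print("training_step called")  # only appears if the closure was run
        return self.layer(batch).sum()

    def train_dataloader(self):
        return DataLoader(torch.randn(64, 32), batch_size=8)

    def configure_optimizers(self):
        return PrintingOptimizer(self.parameters(), lr=0.1)

if __name__ == "__main__":
    pl.Trainer(max_epochs=1, limit_train_batches=2).fit(PrintingModel())
```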

ludwigwinkler commented 3 years ago

Wow ... I never would have thought that simply adding closure() inside my optimizer.step() would solve the entire problem. :) I wasn't aware of that optimizer functionality in pl. My bad ...

Wonderful! Thank you very much @lizhitwo for your help!

lizhitwo commented 3 years ago

Glad to help someone who fell into the same pit as I did. Frankly, I think pytorch lightning should throw a warning if closure() is not called in optimizer_step...