locuslab / deq

[NeurIPS'19] Deep Equilibrium Models
MIT License

Question about Remove Hook #24

Closed. QiyaoWei closed this issue 2 years ago.

QiyaoWei commented 2 years ago

Dear Shaojie,

Hi there! This is Qiyao, a huge fan of your work! I am writing to ask a question about these lines. I notice that if I remove them, training no longer works, but I am having a hard time figuring out why. In my understanding, the program should never create more than one hook in a single forward pass, so I don't see the purpose of this check. For example, this tutorial does not check for the hook, so I am confused about what is happening here.

jerrybai1995 commented 2 years ago

Hi @QiyaoWei ,

Great question. In the tutorial, as you might have noticed, the fixed point z was cloned into z0 and passed through another layer, so you pay an extra computation cost (one more application of f) just to set up the backward hook.
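
Schematically, the tutorial's pattern looks something like this (paraphrased rather than copied verbatim, and assuming a solver that simply returns the fixed point):

```python
import torch
from torch import autograd, nn

class DEQFixedPoint(nn.Module):
    # Paraphrase of the tutorial's DEQ layer, not this repo's code.
    def __init__(self, f, solver, **kwargs):
        super().__init__()
        self.f = f            # equilibrium function z* = f(z*, x)
        self.solver = solver  # black-box fixed-point solver
        self.kwargs = kwargs

    def forward(self, x):
        # Forward solve without building a graph.
        with torch.no_grad():
            z = self.solver(lambda z: self.f(z, x), torch.zeros_like(x),
                            **self.kwargs)
        # One extra f application to re-attach z to the autograd tape.
        z = self.f(z, x)

        # The "clone z into z0 and pass it through another layer" step:
        # f0/z0 form a small separate graph used for vector-Jacobian products.
        z0 = z.clone().detach().requires_grad_()
        f0 = self.f(z0, x)

        def backward_hook(grad):
            # Solve the adjoint fixed point g = (df/dz)^T g + grad.
            g = self.solver(
                lambda y: autograd.grad(f0, z0, y, retain_graph=True)[0] + grad,
                grad, **self.kwargs)
            return g

        z.register_hook(backward_hook)
        return z
```

Because z0 is detached, the inner autograd.grad call only ever touches the small f0/z0 graph and never re-fires the hook registered on z, so there is no recursion, but you pay for the extra f evaluation.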

In contrast, in the implementation provided in this repo, we don't even need to pay this extra computation cost (or clone anything). This requires us to apply the hook directly on z1s. However, if we do NOT remove the hook, then this line will recursively call the hook (since autograd.grad will invoke the backward hook again), and the program will hang.
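
In the same schematic form (illustrative names again, not the repo's exact code), the no-clone variant and the reason the hook must remove itself look like this:

```python
import torch
from torch import autograd, nn

class DEQDirectHook(nn.Module):
    # Sketch of the no-clone variant described above.
    def __init__(self, f, solver, **kwargs):
        super().__init__()
        self.f = f
        self.solver = solver
        self.kwargs = kwargs
        self.hook = None

    def forward(self, x):
        with torch.no_grad():
            z1s = self.solver(lambda z: self.f(z, x), torch.zeros_like(x),
                              **self.kwargs)
        z1s = z1s.requires_grad_()
        new_z1s = self.f(z1s, x)  # single re-engaging call; nothing is cloned

        def backward_hook(grad):
            if self.hook is not None:
                # Without this removal, the autograd.grad call below would
                # flow back through this same tensor, fire this hook again,
                # and recurse until the program hangs.
                self.hook.remove()
                torch.cuda.synchronize()
            g = self.solver(
                lambda y: autograd.grad(new_z1s, z1s, y,
                                        retain_graph=True)[0] + grad,
                grad, **self.kwargs)
            return g

        self.hook = new_z1s.register_hook(backward_hook)
        return new_z1s
```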

I hope this clarifies things for you.

QiyaoWei commented 2 years ago

Gotcha. If I may ask a follow-up question, also about hooks: say I want to do something like Jacobian regularization, except that the regularization term comes from the solver, i.e., aside from z1, I also get my regularization term from this return. Is there a way to let that regularization loss backprop through the solver while keeping everything else intact? Basically, I want to keep the forward solver wrapped in torch.no_grad(), but somehow allow my regularization loss to live outside torch.no_grad(). I'm not sure whether hooks will work in this case, so is that even possible?

jerrybai1995 commented 2 years ago

I don't think I fully understand. If your goal is to backprop through the solver, then you will have to pay all the intermediate activation memory costs anyway, so there is no point in using torch.no_grad(). Or do you expect your regularization loss to use only a tiny portion of the solver information?

QiyaoWei commented 2 years ago

Yep, that's exactly right. Ideally my regularization loss would only need less than half of the solver trace, so I think there is still merit in investigating whether I can keep the original hook routine.

jerrybai1995 commented 2 years ago

Then maybe it's possible to do this at a finer granularity: that is, put the torch.no_grad() inside the solver implementation and keep only the part that you actually need differentiable. Another way is to break the fixed-point solving into two parts: one part within torch.no_grad() and the other part within torch.enable_grad(), as in the sketch below.
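
Here is a minimal sketch of the second option (hypothetical helper names, not code from this repo; f stands for one update step of whatever solver you use):

```python
import torch

def two_phase_solve(f, z0, n_nograd=40, n_grad=5):
    # Part 1: the bulk of the fixed-point iterations. No graph is built,
    # so no intermediate activations are stored.
    z = z0
    with torch.no_grad():
        for _ in range(n_nograd):
            z = f(z)
    # Part 2: a short differentiable tail. Autograd records only these
    # steps, so a loss computed from them backprops through this part only.
    with torch.enable_grad():
        tail = []
        for _ in range(n_grad):
            z = f(z)
            tail.append(z)
    return z, tail  # build your regularization loss from `tail`
```

Your regularization loss would then use tail (say, the last few iterates), and the activation memory cost scales with n_grad rather than with the full solve.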

QiyaoWei commented 2 years ago

Ah I see. That makes sense. Thanks a lot for the quick response!