facebookresearch / higher

higher is a pytorch library allowing users to obtain higher order gradients over losses spanning training loops rather than individual training steps.
Apache License 2.0

Why does higher need to deep copy the parameters of the base model and the use of override? #58

Open renesax14 opened 4 years ago

renesax14 commented 4 years ago

To the point: Why does higher need to deep copy the parameters of the base model (usually the initialization params)?


Expanded question

I noticed this line of code:

https://github.com/facebookresearch/higher/blob/8f0716fb1663218324c02dabdba26b639959cfb6/higher/optim.py#L101

I removed that line in my copy of higher, and doing so fixed some bugs I was experiencing (e.g., the gradients of the parameters from the beginning [i.e. the initialization] unexpectedly being zero for the outer-loop optimizer). I reference it in issue https://github.com/facebookresearch/higher/issues/30.

I don't understand a couple of things. If entering the context manager creates a deep copy of the parameters, wouldn't the (outer-loop) optimizer be computing gradients with respect to parameters that are not present in the computation graph? Since:

  1. the parameters of the differentiable/inner optimizer are deep copies of the initial parameters/weights
  2. the outer optimizer (e.g. Adam) would have the original/initial parameters, so the gradient of these should always be zero.

That is the only explanation I can think of for the issues I had in the past (gradients unexpectedly being zero). However, the higher MAML tutorial seems to work, which contradicts my theory. If my theory were right, then at the end of the MAML inner loop, when the outer optimizer (usually Adam) computes the gradients, they should be zero (which I have sometimes observed). But I assume they are NOT zero, otherwise that tutorial wouldn't work.

So I am asking about the need to use a deep copy when creating inner optimizers. What is its purpose, and why does it not cause the issues I describe in the original MAML tutorial in higher? How is it that the deep copy doesn't break the forward pass, and thus the whole computation of the gradient with respect to the initialization that the outer optimizer would use?
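The theory in the question can be checked on a toy, dependency-free problem. In the sketch below, a hand-rolled forward-mode Dual number (an illustration, not part of PyTorch or higher) tracks the derivative of a meta-loss with respect to the original initial weight, and detach() keeps the value but drops that derivative. The hypothetical helper meta_gradient_wrt_init unrolls a few SGD steps starting either from a detached copy or from a differentiable clone of the initialization. With the detached copy, the meta-gradient reaching the original initialization is exactly zero, which is the symptom described above; with the clone it is nonzero. The flag name mirrors higher's copy_initial_weights, whose True/False behaviour is described later in this thread as detach-and-copy versus clone; treat that mapping as an assumption of the sketch.

```python
class Dual:
    """Toy forward-mode AD number: a value and its derivative w.r.t. the original init."""

    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __sub__(self, other):
        o = Dual.lift(other)
        return Dual(self.val - o.val, self.dot - o.dot)

    def __mul__(self, other):
        o = Dual.lift(other)
        return Dual(self.val * o.val, self.dot * o.val + self.val * o.dot)

    __rmul__ = __mul__

    def detach(self):
        """Same value, but no derivative flows back to the number it came from."""
        return Dual(self.val, 0.0)


def meta_gradient_wrt_init(w_init, copy_initial_weights, lr=0.1, steps=3, target=0.5):
    """Hypothetical helper: d(meta-loss)/d(original init) after an unrolled inner loop."""
    if copy_initial_weights:
        w = w_init.detach()                 # detached copy: cut off from w_init
    else:
        w = Dual(w_init.val, w_init.dot)    # clone: new object, derivative preserved
    for _ in range(steps):
        w = w - lr * (w - target)           # inner SGD on 0.5 * (w - target) ** 2
    meta_loss = 0.5 * (w - target) * (w - target)
    return meta_loss.dot


w_init = Dual(1.0, 1.0)  # seed: d(w_init)/d(w_init) = 1
print(meta_gradient_wrt_init(w_init, copy_initial_weights=True))   # 0.0
print(meta_gradient_wrt_init(w_init, copy_initial_weights=False))  # nonzero
```

In other words, a zero gradient on the initialization is what one would expect whenever the fast weights start from a detached copy, and not when they start from a clone.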


https://stackoverflow.com/questions/62437960/why-does-higher-need-to-deep-copy-the-parameters-of-the-base-model-to-create-a-f

renesax14 commented 4 years ago

Pasting some original response about this from @egrefen:

That line of code is important, as we want to safely branch off the state of the optimizer as used in the outer loop and return to it (or not touch it in the first place, which is what we do with the copy here) at the end of the unrolled inner loop.

Use override, please.

Override is a kwarg for differentiable optims (at creation, or step time, and you can also use it with the context manager) which allows you to use arbitrary tensors instead of values held in the optimizer state. For example, you could override the learning rate with a tensor which requires grad, which would allow you to unroll your loops, take gradient of the meta-loss with regard to the learning rate, and update this tensor.

See https://higher.readthedocs.io/en/latest/optim.html for details, #32 (comment) for a similar explanation, and https://github.com/denisyarats/densenet_cifar10 for an example.
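The learning-rate example above can be sketched without PyTorch or higher. In this toy version, a hand-rolled forward-mode Dual number stands in for a learning-rate tensor that requires grad, opt_state is a hypothetical dict standing in for the optimizer's state, and the two quadratic losses are arbitrary choices. The sketch unrolls a few SGD steps with the differentiable learning rate, takes the derivative of the meta-loss with respect to it, updates it in an outer loop, and then assigns the learned value back to the optimizer state outside the unrolled loop. It mirrors the data flow only, not higher's API.

```python
class Dual:
    """Toy forward-mode AD number standing in for a tensor with requires_grad=True."""

    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __sub__(self, other):
        o = Dual.lift(other)
        return Dual(self.val - o.val, self.dot - o.dot)

    def __mul__(self, other):
        o = Dual.lift(other)
        return Dual(self.val * o.val, self.dot * o.val + self.val * o.dot)

    __rmul__ = __mul__


def unrolled_meta_loss(lr_value, w0=1.0, steps=5, train_t=0.0, val_t=0.2):
    """Unroll `steps` SGD updates with a differentiable lr; return (meta-loss, d/d lr)."""
    lr = Dual(lr_value, 1.0)  # the "override": lr is differentiable, seeded with d(lr)/d(lr) = 1
    w = Dual(w0)              # fast weight; its derivative w.r.t. lr starts at 0
    for _ in range(steps):
        w = w - lr * (w - train_t)  # SGD on the inner loss 0.5 * (w - train_t) ** 2
    d = w - val_t
    meta = 0.5 * d * d            # meta-loss measured against a different target
    return meta.val, meta.dot


# Hypothetical non-differentiable optimizer state holding the learning rate.
opt_state = {"param_groups": [{"lr": 0.05}]}

meta_lr = opt_state["param_groups"][0]["lr"]
for _ in range(20):
    _, grad = unrolled_meta_loss(meta_lr)
    meta_lr -= 0.05 * grad        # outer-loop gradient step on the meta-variable

# Outside the unrolled loop, assign the learned value back to the optimizer state.
opt_state["param_groups"][0]["lr"] = meta_lr
print(opt_state["param_groups"][0]["lr"])
```

In higher, the differentiable learning rate would instead be a tensor passed through the override kwarg, and the final assignment corresponds to writing the meta-learned value back into the optimizer's settings.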

renesax14 commented 4 years ago

I don't understand what:

That line of code is important, as we want to safely branch off the state of the optimizer as used in the outer loop and return to it (or not touch it in the first place, which is what we do with the copy here) at the end of the unrolled inner loop.

means. I still think (perhaps naively) that removing that deep copy is the best way to fix the zero-gradient problem I've been experiencing with my initialization.
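One way to read "branch off the state of the optimizer" is shown in this toy, dependency-free sketch. The dict is a hypothetical stand-in for an optimizer's settings plus a momentum buffer (loosely modelled on torch.optim conventions, not higher's internals). The unrolled inner loop runs on a copy.deepcopy of that state, so when the loop ends the outer state is exactly as it was, and the branch can simply be discarded.

```python
import copy

# Hypothetical stand-in for an optimizer's state: settings in param_groups plus a
# per-parameter momentum buffer (loosely modelled on torch.optim, not higher's internals).
outer_state = {
    "param_groups": [{"lr": 0.1, "momentum": 0.9}],
    "momentum_buffer": {"w": 0.0},
}


def inner_sgd_step(state, w, grad):
    """One momentum-SGD step; mutates `state` in place and returns the new weight."""
    group = state["param_groups"][0]
    buf = group["momentum"] * state["momentum_buffer"]["w"] + grad
    state["momentum_buffer"]["w"] = buf
    return w - group["lr"] * buf


def unroll_inner_loop(state, w0, target, steps):
    """Run `steps` inner updates on a deep copy of `state` (the 'branch')."""
    inner_state = copy.deepcopy(state)  # branch off; `state` itself is never mutated
    w = w0
    for _ in range(steps):
        grad = w - target  # gradient of the inner loss 0.5 * (w - target) ** 2
        w = inner_sgd_step(inner_state, w, grad)
    return w, inner_state  # the caller may simply discard inner_state afterwards


w_final, inner_state = unroll_inner_loop(outer_state, w0=1.0, target=0.0, steps=3)
print(outer_state["momentum_buffer"]["w"])  # 0.0: the outer state is unchanged
print(inner_state["momentum_buffer"]["w"])  # momentum accumulated only in the branch
```

Without the deep copy, the three inner steps would leave a nonzero momentum buffer in outer_state and silently change the next outer-loop step.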

renesax14 commented 4 years ago

Apologies for being so demanding, but I don't think I understand what the override option means:

override (optional) – a dictionary mapping optimizer settings (i.e. those which would be passed to the optimizer constructor or provided within parameter groups) to either singleton lists of override values, or to a list of override values of length equal to the number of parameter groups. If a single override is provided for a keyword, it is used for all parameter groups. If a list is provided, the ith element of the list overrides the corresponding setting in the ith parameter group. This permits the passing of tensors requiring gradient to differentiable optimizers for use as optimizer settings.

in particular:

This permits the passing of tensors requiring gradient to differentiable optimizers for use as optimizer settings.

What does that mean? What are "optimizer settings"? What does "the passing of tensors requiring gradient to differentiable optimizers" mean?

Wouldn't one always expect both the parameters of the optimizer AND the initialization to be differentiable with respect to the outer optimizer by default? I can't think of a case where we would want the weights in the differentiable optimizer to be non-differentiable (e.g. meta-LSTM meta-learning or training the initialization with MAML).
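The override rule quoted above can be made concrete. "Optimizer settings" are the per-group hyperparameters (for example lr, or momentum), and override replaces them either with one value for every parameter group or with one value per group. The hypothetical resolve_overrides helper below mirrors only that documented rule on plain dicts; it is not higher's code. In real use the values would be tensors that require grad, which is what lets a meta-gradient flow into them.

```python
def resolve_overrides(param_groups, override):
    """Return per-group settings with `override` applied (hypothetical helper).

    Mirrors the documented rule: a singleton list applies to every group, and a
    list of length len(param_groups) applies element-wise, group by group.
    """
    n = len(param_groups)
    resolved = [dict(group) for group in param_groups]  # leave the originals untouched
    for key, values in (override or {}).items():
        if len(values) == 1:
            values = list(values) * n
        elif len(values) != n:
            raise ValueError(
                f"override for {key!r} has {len(values)} values; expected 1 or {n}"
            )
        for group, value in zip(resolved, values):
            group[key] = value
    return resolved


groups = [{"lr": 0.1}, {"lr": 0.01}]
print(resolve_overrides(groups, {"lr": [0.5]}))        # [{'lr': 0.5}, {'lr': 0.5}]
print(resolve_overrides(groups, {"lr": [0.5, 0.05]}))  # [{'lr': 0.5}, {'lr': 0.05}]
```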

renesax14 commented 4 years ago

for reference, meta-lstm is this paper:

https://openreview.net/pdf?id=rJY0-Kcll

I am happy to contribute an implementation as an example (and have it checked) once all of this is clear, since my implementation currently comments out this deep copy line of code, which I don't think is correct.

renesax14 commented 4 years ago

I suspect that a lot of this might get cleared up once the meaning of "copy_initial_weights" is clarified in: https://github.com/facebookresearch/higher/issues/54

egrefen commented 4 years ago

Hello, sorry I haven't had time to look into this. I'll read through this ASAP and get back to you regarding this issue (hopefully later this week).

renesax14 commented 4 years ago

I suspect that a lot of this might get cleared up once the meaning of "copy_initial_weights" is clarified in: #54

The meaning of "copy_initial_weights" is much clearer to me now, but it did NOT resolve my question about the deep copy.

When it's True, we detach the parameters and make a copy of them. That probably means the original parameters of the base model are not trained.

When it's False, the parameters are only cloned. I assume that's done to avoid mistakes with in-place ops, but gradients can still flow, so the initialization of the base model should be trainable.

However, that could be achieved without deep copying the parameters of the original optimizer, e.g. by doing:

self.param_groups = other.param_groups

Thus, deep copying remains a mystery to me.
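On the proposed alternative: in Python, plain assignment makes both optimizers refer to the same list and dicts rather than copying them. The hypothetical classes below (not higher's) show the consequence: anything written into the shared param_groups during the inner loop also changes the outer optimizer, whereas a deep copy keeps the two independent. This only illustrates Python's aliasing semantics; it makes no claim about which fields higher writes during an unrolled loop.

```python
import copy


class ToyOptimizer:
    """Hypothetical stand-in that stores settings in param_groups, like torch.optim does."""

    def __init__(self, lr):
        self.param_groups = [{"lr": lr}]


class AliasingInnerOpt:
    """Proposed variant: shares the outer optimizer's param_groups list and dicts."""

    def __init__(self, other):
        self.param_groups = other.param_groups


class CopyingInnerOpt:
    """Deep-copying variant: gets an independent branch of param_groups."""

    def __init__(self, other):
        self.param_groups = copy.deepcopy(other.param_groups)


outer = ToyOptimizer(lr=0.1)
inner = AliasingInnerOpt(outer)
inner.param_groups[0]["lr"] = 0.5   # an inner loop writing to one of its settings
print(outer.param_groups[0]["lr"])  # 0.5: the outer optimizer changed too

outer = ToyOptimizer(lr=0.1)
inner = CopyingInnerOpt(outer)
inner.param_groups[0]["lr"] = 0.5
print(outer.param_groups[0]["lr"])  # 0.1: the outer optimizer is untouched
```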

renesax14 commented 4 years ago

Perhaps crucially related to this issue: https://github.com/facebookresearch/higher/issues/62

egrefen commented 4 years ago

Sorry for the delay in looking into this. I remain the main developer of this library, but it's not my main role at Facebook, so I only have time to chase these issues here and there. I am reading through your comments and trying to figure out how best to proceed.

egrefen commented 4 years ago

The short answer: differentiable optimizers branch off the state of their non-differentiable version because the patched modules they optimize are themselves branched off from the original modules. In inner loop meta-learning, you typically branch off like this, do some inner loop computation, compute the meta-gradient, and use that meta-gradient to update some meta-variables. The result of this inner loop computation is then discarded, so you generally don't want to accidentally backpropagate stuff or update the module or optimizer state you branched off from. There are cases where you do want to do this, e.g. MAML in the module case, and for that we use copy_initial_weights=False, or the case where the meta-variables are part of the optimizer, in which case we use override to replace the usually non-differentiable variables used in the optimizer with tensors, which we then need to assign back to the optimizer state outside of the unrolled loop (see https://github.com/denisyarats/densenet_cifar10 for example).

Long answer: I see that this is all guided by your desire to implement meta-LSTMs. Let's see if we can resolve that issue in #62 first, and then come back here to see if anything needs to be done.