facebookresearch / higher

higher is a pytorch library allowing users to obtain higher order gradients over losses spanning training loops rather than individual training steps.
Apache License 2.0

Improving documentation for copy_initial_weights #30

Closed renesax14 closed 4 years ago

renesax14 commented 4 years ago

I suggest improving the English used in the documentation for the following:

copy_initial_weights – if true, the weights of the patched module are copied to form the initial weights of the patched module, and thus are not part of the gradient tape when unrolling the patched module. If this is set to False, the actual module weights will be the initial weights of the patched module. This is useful when doing MAML, for example.

For example, "the weights of the patched module are copied to form the initial weights of the patched module" doesn't make sense to me because when the context manager is initiated a patched module does not exist yet. So it is unclear what we are copying from and to where (and why copying is something we want to do).

Also, "unrolling the patched module" does not make sense to me. We usually unroll a computaiton graph caused by a for loop. A patched module is just a neural net that has been modified by this library. Unrolling is ambiguous.

Also, there isn't a technical definition for "gradient tape".

Also, when describing what False does, saying that it's useful for MAML isn't actually helpful because it doesn't even hint at why it's useful for MAML.

Overall, it's impossible to use the context manager because it's unclear what that flag is supposed to be doing (and it seems critical, which makes it odd that it has a default value).



egrefen commented 4 years ago

Hello! Sorry to hear you've been having issues with our documentation. We can definitely improve it!

As you probably know, higher is meant to be used to define an "unrolled optimization loop" through which you can backpropagate the gradient of a meta-loss with respect to suitable meta-parameters within the loop (e.g. learning rates, parametric loss functions, etc.). Because calling metaloss.backward() propagates gradient to all tensors with requires_grad=True, this typically includes the target model parameters (since they are being optimized in the inner loop).

To avoid accidentally accumulating gradients on the model parameters when optimizing meta-parameters with higher, and to avoid imposing on the coder the need to place things like model.zero_grad() after an inner loop or in other places they wouldn't expect to need to, we by default copy the parameters of the model being unrolled at the beginning of an unrolled loop, to cut the gradient path back to the original model parameters.

There are two reasons you might not want to do this:

  1. Memory: this creates a technically unnecessary copy of the model weights for the lifetime of the unrolled loop (usually short). Because unrolling the optimization loop is already O(n) in memory in the loop length n, this extra copy is usually not noticeable, but if it's an issue, you might want to disable the copy and manually zero grads on the original model after the inner loop (see the sketch after this list).
  2. MAML (and related methods): here the meta-parameters are the model parameters at the beginning of the loop (basically meta-learning initialization) and so obtaining gradient on the original model parameters with the call to metaloss.backward() is precisely what you hope to achieve, so here you want to suppress the copying.
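A minimal, hypothetical sketch contrasting the two settings (the toy model and data here are made up; it assumes the higher.innerloop_ctx API used in the example further down):

import torch
import torch.nn.functional as F
import higher

net = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

for copy_initial_weights in (True, False):
    net.zero_grad()
    with higher.innerloop_ctx(
        net, opt, copy_initial_weights=copy_initial_weights
    ) as (fnet, diffopt):
        for _ in range(3):  # unrolled inner loop
            diffopt.step(F.mse_loss(fnet(x), y))
        meta_loss = F.mse_loss(fnet(x), y)
        meta_loss.backward()
    got_grad = any(p.grad is not None and p.grad.abs().sum() > 0
                   for p in net.parameters())
    # expected, per the explanation above: True -> no grad on the original
    # parameters (path severed), False -> the original parameters receive grad
    print(copy_initial_weights, '->', got_grad)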

The references to "gradient tape" are confusing, I agree, and I should use the term "implicit backward graph" or something.

So two questions for you, @renesax14:

A. Does this explanation clarify things?
B. Would you be interested in suggesting an alternative docstring for the keyword, which I can update the codebase with? If not, I'll draft something when I've heard back from you and solicit your feedback, but a small contribution in this form would be very appreciated!

renesax14 commented 4 years ago

@egrefen That was very useful feedback. Thank you for that. I will think about what a better docstring would be but feel free to suggest stuff and I am happy to provide feedback.

I think I understand what the docstring means but let me try to re-phrase to be sure.

When copy_initial_weights=True, it means that every time we reach that context manager we create a new copy of the parameters being fed in before entering the context manager. The way I understand this is that it essentially implements the way "normal" gradient descent would work. When we do W^<t> := W^<t-1> - eta*Grad(W^<t-1>, L_train) normally we use the weights from the previous iteration (no meta-learning). So here, we allow the meta-parameters to be treated as in normal gradient descent and not "string them" all the way back to the beginning. That's what True does.

When False, the meta-update is also part of the computation graph and it goes back all the way to the first update ever done by any code we wrote (assuming the flag is set consistently throughout meta-training).

I think that's what's going on, based on your description. I will think about whether there is a more docstring-friendly way to explain this, but first, is this correct?

(I also hadn't realized that MAML actually unrolled so far back. I always assumed that after a meta-param update step it forgot the previous meta-update step [as in, not including it in the computation graph], but reading what you said, I don't see why we wouldn't want to go back all the way to the beginning, and in MAML it makes sense, though setting it to False and returning those values as a meta-learned prior ready to fine-tune seemed like a fine option too.)

egrefen commented 4 years ago

When copy_initial_weights=True, it means that every time we reach that context manager we create a new copy of the parameters being fed in before entering the context manager.

Yes. And if False, the parameters of the original model are used at the beginning of the unrolled loop, and receive gradient.

The way I understand this is that it essentially implements the way "normal" gradient descent would work.

Not sure I follow what you mean...

When we do W^<t> := W^<t-1> - eta*Grad(W^<t-1>, L_train) normally we use the weights from the previous iteration (no meta-learning).

That's correct. Here we treat this update as differentiable and keep track of the intermediate computations in order to backpropagate the meta-loss. This is done whether copy_initial_weights is true or false.

So here, we allow the meta-parameters to be treated as in normal gradient descent and not "string them" all the way back to the beginning. That's what True does.

No, copy_initial_weights=True just refers to the fact that W^0 (where t=0 is the first timestep of the unrolled loop) are copies of the parameters of the model at the time where the unroll is started, and if copy_initial_weights=False they are a reference to the original parameters.

renesax14 commented 4 years ago

No, copy_initial_weights=True just refers to the fact that W^0 (where t=0 is the first timestep of the unrolled loop) are copies of the parameters of the model at the time where the unroll is started, and if copy_initial_weights=False they are a reference to the original parameters.

Sorry if I still don't get it. I think my main confusion is whether we are unrolling the outer loop of the meta-training too or just the inner loop. I understand that we take the derivative of the meta-loss w.r.t. the inner-loop optimization steps. But if we have more than 1 iteration of meta-training, does copy_initial_weights affect whether we differentiate over the meta-training step too or not?

So for MAML we might have:

## let W^<out_idx, inner_idx> denote the parameters of the model at inner index inner_idx and outer index out_idx
# first outer-loop
W^<0,0> = initialize
W^<0,1> = W^<0,0> - eta*grad(W^<0,0>, L_train) # inner-loop
W^<0,2> = W^<0,1> - eta*grad(W^<0,1>, L_train) # inner-loop
W^<1,0> = W^<0,2> - eta*grad(W^<0,2>, L_val) # SGD on meta-loss
# next outer-loop
W^<1,1> = W^<1,0> - eta*grad(W^<1,0>, L_train) # inner-loop
W^<1,2> = W^<1,1> - eta*grad(W^<1,1>, L_train) # inner-loop
W^<2,0> = W^<1,2> - eta*grad(W^<1,2>, L_val) # SGD on meta-loss

So there are two options: each time we start an outer loop we either "detach" the computation graph, or we connect everything back to the beginning and thus have a (potentially) massive computation graph. Is this second option ever done?

egrefen commented 4 years ago

MAML is a special case of inner-loop meta-learning (see this paper for a detailed formalization of what we mean by this). In MAML, the only thing the outer loop does is sample a meta-batch of new tasks, unroll the inner loop, take a meta-loss, calculate meta-gradients, and update the "initialization" of the model (we assume the model is in a locally initial state at the beginning of each k-shot learning episode) using the meta-gradients. In this sense, the meta-parameters being learned are the model parameters (or rather the initial parameters of the model).

So when we unroll the inner loop, instead of unrolling starting with a copy of the model (which is what we would do if we wanted to avoid meta-gradient accidentally propagating to the model parameters), we use the model itself as the initial point of the unrolled loop. As you unroll the loop, the state of the model at each inner step > 0 is not a copy of the original model, but rather a node in the differentiable graph which links back (possibly via other nodes) to the parameters at the beginning of the unrolled loop. Hence, when I call backward on the output of a differentiable function (e.g. the meta-loss) of the last state in this unrolled loop, gradient is guaranteed to propagate to every part of the unrolled loop including the initial state, allowing me to train the parts of the optimization process that matter to me (the meta-variables). And in the case of MAML, this is precisely the initial state at the beginning of the unrolled loop, which in turn is the model state itself in the outer loop, which is why I do not want to copy in this specific case (as that would destroy the gradient pathway to what I am trying to optimize).

Does this make sense?

JonMuehlst commented 4 years ago

This discussion has helped clarify the purpose of "copy_initial_weights". I still have a question: Why does copying the weights render them as "not part of the gradient tape..."?

egrefen commented 4 years ago

I still have a question: Why does copying the weights render them as "not part of the gradient tape..."?

The copy receives gradient, but the "backwards" connection to the original weights is severed.
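To make this concrete, here's a hedged toy analogy in plain PyTorch (no higher involved; the tensor names are made up): cloning and detaching the starting point keeps the copy trainable, but cuts the backward path to the original tensor.

import torch

w_orig = torch.randn(3, requires_grad=True)

for copy_initial in (True, False):
    w0 = w_orig.clone().detach().requires_grad_(True) if copy_initial else w_orig
    # one hand-rolled differentiable "inner step"
    inner_grad = torch.autograd.grad((w0 ** 2).sum(), w0, create_graph=True)[0]
    w1 = w0 - 0.1 * inner_grad
    meta_loss = (w1 ** 2).sum()
    g = torch.autograd.grad(meta_loss, w_orig, allow_unused=True)[0]
    # copy=True  -> None (backwards connection severed)
    # copy=False -> an actual gradient tensor
    print('copy:', copy_initial, '-> grad on original:', g)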

renesax14 commented 4 years ago

@egrefen

Apologies for dragging this out further, but I'd rather be safe than sorry and admit I still don't understand.

I would never have expected higher to create copies of anything. I always want to train the initialization weights from the beginning... why wouldn't I? I would have expected that zeroing the outer optimizer meta_opt would clear those gradients out and nothing unexpected would accumulate... is it not right that clearing the grads of the outer optimizer is enough, as in:

import time

import torch
import torch.nn.functional as F

import higher


def train(db, net, device, meta_opt, epoch, log):
    net.train()
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?

        # Initialize the inner optimizer to adapt the parameters to
        # the support set.
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        qry_losses = []
        qry_accs = []
        for i in range(task_num):
            with higher.innerloop_ctx(
                net, inner_opt, copy_initial_weights=False
            ) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                # higher is able to automatically keep copies of
                # your network's parameters as they are being updated.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The final set of adapted parameters will induce some
                # final loss and accuracy on the query dataset.
                # These will be used to update the model's meta-parameters.
                qry_logits = fnet(x_qry[i])
                qry_loss = F.cross_entropy(qry_logits, y_qry[i])
                qry_losses.append(qry_loss.detach())
                qry_acc = (qry_logits.argmax(
                    dim=1) == y_qry[i]).sum().item() / querysz
                qry_accs.append(qry_acc)

                # Update the model's meta-parameters to optimize the query
                # losses across all of the tasks sampled in this batch.
                # This unrolls through the gradient steps.
                qry_loss.backward()

        meta_opt.step()
        meta_opt.zero_grad() # <---- THIS ZEROS GRADIENTS OUT

code adapted from: https://github.com/facebookresearch/higher/blob/master/examples/maml-omniglot.py

renesax14 commented 4 years ago

@egrefen

Based on this response (from your bullet points on why we would want to suppress copying the weights):

MAML (and related methods): here the meta-parameters are the model parameters at the beginning of the loop (basically meta-learning initialization) and so obtaining gradient on the original model parameters with the call to metaloss.backward() is precisely what you hope to achieve, so here you want to suppress the copying.

it seems that whenever we train the initialization we should have copy_initial_weights=False, is that correct?

renesax14 commented 4 years ago

based on this comment:

To avoid accidentally accumulating gradients on the model parameters when optimizing meta-parameters with higher, and to avoid imposing on the coder the need to place things like model.zero_grad() after an inner loop or in other places they wouldn't expect to need to

it makes me think that, because the inner loop is computing gradients, it might accumulate gradients from the inner loop and add them to the gradients from the outer loop (probably incorrectly). If the gradient operation is just another operation in the graph, doing this would seem incorrect to me. How do we stop this? I want to train the initialization vector AND not accumulate anything accidentally or incorrectly from the inner loop.

egrefen commented 4 years ago

it seems that whenever we train the initialization we should have copy_initial_weights=False, is that correct?

Yes. We will probably be renaming this to something clearer in the next big release, and also cutting down on needless copying. For now, if you want to do something like MAML, set this to False.

egrefen commented 4 years ago

it makes me think that, because the inner loop is computing gradients, it might accumulate gradients from the inner loop and add them to the gradients from the outer loop (probably incorrectly). If the gradient operation is just another operation in the graph, doing this would seem incorrect to me. How do we stop this? I want to train the initialization vector AND not accumulate anything accidentally or incorrectly from the inner loop.

If copy_initial_weights=False, no gradient is passed outside the inner loop during the inner loop's unrolling. However, if you then compute a meta-loss metaloss as a function of the weights at the end of the inner loop and call metaloss.backward(), gradients will propagate to your outer context if anything in your inner loop is a differentiable function of something in the outer context (although typically this is not the case, I cannot comment on your set-up without knowing it). This is due to how backward works in pytorch, and not something we can control here. You can avoid issues by either:

  1. Using torch.autograd.grad instead of backward. This allows you to specify the exact inputs with respect to which you are differentiating something (e.g. the meta-loss). You would then need to manually assign the returned gradients to the .grad attributes if you want to optimize whatever you are optimizing with an outer loop optimizer (see the sketch after this list).
  2. Use backward, optimize what you need to optimize, and make sure you zero grads for everything else manually.
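A minimal, hypothetical sketch of option 1 (the helper name meta_step is made up, not part of higher's API): differentiate the meta-loss with respect to exactly the parameters you want to meta-train, assign .grad by hand, then step the outer optimizer.

import torch

def meta_step(meta_loss, params, meta_opt):
    # Differentiate meta_loss only w.r.t. `params` (e.g. net.parameters()),
    # so nothing else accidentally accumulates gradient.
    params = list(params)
    grads = torch.autograd.grad(meta_loss, params, allow_unused=True)
    for p, g in zip(params, grads):
        if g is not None:
            p.grad = g if p.grad is None else p.grad + g  # accumulate across tasks
    meta_opt.step()
    meta_opt.zero_grad()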

egrefen commented 4 years ago

Closing this issue, as it will be addressed as part of resolving issue #54. Thanks for the feedback!