pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

How to update the original model parameters after calling make_functional? #280

Open trenta3 opened 2 years ago

trenta3 commented 2 years ago

As per the title, I find that updating the tensors pointed to by the params returned by make_functional does not update the real parameters in the original model. Is there a way to do this? It would be extremely useful for implementing optimization algorithms in a way that is closer to their mathematical description.

To provide more context, here is an example script of what standard gradient descent would look like written this way:

import torch
from torch import nn
from functorch import make_functional

learning_rate = 0.1

def optstep(params, jacobians):
    with torch.no_grad():
        for i, param in enumerate(params):
            param.add_(jacobians[i], alpha=-learning_rate)

if __name__ == '__main__':
    model = nn.Linear(3, 5)
    x, targets = torch.randn(2, 3), torch.randn(2, 5)
    criterion = nn.MSELoss()

    print("INITIAL LOSS:", criterion(model(x), targets).item())
    # Render the model functional and compute the jacobian
    func_model, params = make_functional(model)
    def f(*params):
        out = func_model(params, x)
        return criterion(out, targets)
    jacobian = torch.autograd.functional.jacobian(f, params)

    # Ideally would train on the current input
    optstep(params, jacobian)
    # Now compute the new loss
    print("NEW LOSS:", criterion(model(x), targets).item())

Executing the script shows that the parameters are not updated, since the loss doesn't change:

INITIAL LOSS: 1.2894147634506226
NEW LOSS: 1.2894147634506226

trenta3 commented 2 years ago

After looking through the source code a bit, I found functorch._src.make_functional.extract_weights and load_weights, which allow me to do exactly what I wanted. Could those functions be exposed and documented to support the suggested use case?

zou3519 commented 2 years ago

Couldn't you do

def optstep(model, jacobians):
    with torch.no_grad():
        for i, param in enumerate(model.parameters()):
            param.add_(jacobians[i], alpha=-learning_rate)

?

(Also, you might want to try functorch.jacrev instead of torch.autograd.functional.jacobian -- it may be faster)
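
A minimal sketch of the suggestion above, assuming PyTorch 2.x, where the functorch transforms are available under torch.func. It uses torch.func.functional_call and torch.func.grad instead of make_functional and torch.autograd.functional.jacobian (for a scalar loss the gradient and the Jacobian carry the same information). The helper name sgd_step is hypothetical; the update is applied to the live model parameters, so the loss changes after the step.

```python
import torch
from torch import nn
from torch.func import functional_call, grad

learning_rate = 0.1

def sgd_step(model, x, targets):
    """Take one gradient-descent step, updating `model` in place (hypothetical helper)."""
    # Detached views of the live parameters, keyed by name.
    params = {name: p.detach() for name, p in model.named_parameters()}

    def compute_loss(p):
        # Run the module with the parameters supplied in `p`.
        out = functional_call(model, p, (x,))
        return nn.functional.mse_loss(out, targets)

    grads = grad(compute_loss)(params)  # dict with the same keys as `params`
    with torch.no_grad():
        for name, p in model.named_parameters():
            p.add_(grads[name], alpha=-learning_rate)

torch.manual_seed(0)
model = nn.Linear(3, 5)
x, targets = torch.randn(2, 3), torch.randn(2, 5)
criterion = nn.MSELoss()

initial_loss = criterion(model(x), targets).item()
sgd_step(model, x, targets)
new_loss = criterion(model(x), targets).item()
print("INITIAL LOSS:", initial_loss)
print("NEW LOSS:", new_loss)
```

Keying gradients by parameter name avoids relying on any ordering between two separate parameter collections.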

trenta3 commented 2 years ago

Is model.parameters() guaranteed to return parameters in the same order as make_functional?

If this is the case then I can surely do this; however, I would like to ask that it be documented as proper behaviour that one can rely on.

Thank you very much

zou3519 commented 2 years ago

Is model.parameters() guaranteed to return parameters in the same order as make_functional?

Yes

If this is the case then I can surely do this; however, I would like to ask that it be documented as proper behaviour that one can rely on.

Yes, we should document this
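
Until this is documented, the ordering can be checked empirically. The sketch below assumes an environment where `from functorch import make_functional` still imports (in PyTorch 2.x it is deprecated in favour of torch.func.functional_call and may emit a warning). It compares the returned params against model.parameters() position by position and also reports whether the two share storage; the variable names are illustrative only.

```python
import torch
from torch import nn
from functorch import make_functional

torch.manual_seed(0)
# Layers with distinct shapes, so a mismatch in order would be detected.
model = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 2))
func_model, params = make_functional(model)
model_params = list(model.parameters())

same_count = len(params) == len(model_params)
# torch.equal is False for tensors of different shape or different values.
same_order = all(torch.equal(fp, mp) for fp, mp in zip(params, model_params))
# If make_functional works on a copy of the model (which would explain the
# unchanged loss reported at the top of this thread), this is False.
shares_storage = any(
    fp.data_ptr() == mp.data_ptr() for fp, mp in zip(params, model_params)
)

print("same count:", same_count)
print("same order:", same_order)
print("shares storage with the original model:", shares_storage)
```

This is only a spot check on one model, not a substitute for a documented guarantee.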

trenta3 commented 2 years ago

Thank you very much again for all this work. I think the issue can be closed as soon as the behaviour is documented.

zou3519 commented 2 years ago

@trenta3 out of curiosity, what are you using make_functional for? Are you using any of the other functorch APIs?

trenta3 commented 2 years ago

I'm currently using make_functional as well as other functorch APIs, in particular jvp and jacrev, to easily write more complex optimizers that also need to consider second-order information of a neural network, which is unmanageable to do in plain PyTorch. Earlier this year I needed to compute eigenvectors of the linearizations of some neural networks, and the ability to obtain gradients for each example separately was crucial.
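
For readers unfamiliar with the per-example gradients mentioned here, a minimal sketch follows, assuming PyTorch 2.x with torch.func. It composes vmap with grad over a single-example loss; the function name sample_loss and the toy model are illustrative and not taken from the use case described above.

```python
import torch
from torch import nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
model = nn.Linear(3, 5)
params = {name: p.detach() for name, p in model.named_parameters()}
x, targets = torch.randn(2, 3), torch.randn(2, 5)

def sample_loss(p, xi, ti):
    # Loss for a single example; add a batch dimension of size 1 for the module.
    out = functional_call(model, p, (xi.unsqueeze(0),))
    return nn.functional.mse_loss(out, ti.unsqueeze(0))

# Map over the batch dimension of x and targets, but not over the parameters.
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, x, targets)

def batch_loss(p):
    return nn.functional.mse_loss(functional_call(model, p, (x,)), targets)

# Gradient of the usual batch-mean loss, for comparison.
batch_grads = grad(batch_loss)(params)

print({name: tuple(g.shape) for name, g in per_sample_grads.items()})
```

Each entry of per_sample_grads has a leading dimension equal to the batch size, and averaging over that dimension recovers the gradient of the batch-mean loss.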

If I must say it, one thing I miss is the ability to "lazily" compute parts of the Hessian, like extracting its diagonal, without paying the full memory (and compute) cost of calculating the whole Hessian. More generally, the ability for a PyTorch user to manipulate "lazy tensors" (i.e. a thunk of computation depending on some data, but which is not eagerly executed) would be extremely useful for computing the diagonal of the Hessian, as well as for many computations on kernel methods (like PyKeOps does), but I sincerely don't know how efficient this can be made.
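
This does not provide the lazy diagonal asked for, but a related building block already exists: a Hessian-vector product can be computed without materialising the Hessian by taking a jvp of grad (forward-over-reverse). The sketch below assumes PyTorch 2.x with torch.func; the function f and the helper name hvp are illustrative. Estimators of the diagonal (for example Hutchinson-style random probing) can be built on top of such products, at the cost of being approximate.

```python
import torch
from torch.func import grad, jvp

def f(x):
    # Toy scalar function with Hessian diag(6 * x) + 2 * ones((n, n)).
    return (x ** 3).sum() + x.sum() ** 2

def hvp(func, x, v):
    """Hessian-vector product H(x) @ v without forming H (hypothetical helper)."""
    # Differentiate the gradient function along direction v.
    _, hv = jvp(grad(func), (x,), (v,))
    return hv

torch.manual_seed(0)
x = torch.randn(4)
v = torch.randn(4)

hv = hvp(f, x, v)
# Closed form for this particular f, used only to check the result.
expected = 6 * x * v + 2 * v.sum()
print(torch.allclose(hv, expected, atol=1e-5))
```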

kxhit commented 2 years ago

Hi! Thanks a lot for building this awesome functorch!

I have the same issue. I'm using fmodel, params, buffers = combine_state_for_ensemble(models) to stack models and optimizing the params in a training loop. After this, I want to update each original model's state_dict(), but I can't find a nice way to achieve this. What I am currently doing is:

with torch.no_grad():
    for idx, model in enumerate(models):
        for i, param in enumerate(model.parameters()):
            param.set_(params[i][idx])

I hope there is a nicer way to achieve this, ideally with a good tutorial. Thanks!
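
One possible variant of the loop above, as a sketch rather than an official recipe: it assumes PyTorch 2.x, where torch.func.stack_module_state returns the stacked parameters as a dict keyed by name, and it uses copy_ instead of set_. The helper name write_back is hypothetical, and the "+ 1.0" merely stands in for a real training loop.

```python
import torch
from torch import nn
from torch.func import stack_module_state

def write_back(models, stacked_params):
    """Copy slice idx of every stacked parameter into models[idx] (hypothetical helper)."""
    with torch.no_grad():
        for idx, model in enumerate(models):
            for name, param in model.named_parameters():
                param.copy_(stacked_params[name][idx])

torch.manual_seed(0)
models = [nn.Linear(3, 2) for _ in range(4)]
stacked_params, stacked_buffers = stack_module_state(models)

# Stand-in for training: pretend the optimiser shifted every weight by 1.
updated = {name: p.detach() + 1.0 for name, p in stacked_params.items()}

write_back(models, updated)
print(torch.equal(models[0].weight, updated["weight"][0]))
```

copy_ writes the values into each model's existing storage, so the models stay independent of the stacked tensors. set_ instead makes each parameter alias a slice of the stacked tensor, which may or may not be what is wanted when the stacked tensors are later modified in place.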

zou3519 commented 2 years ago

@kxhit thank you for your feedback. Could you give a little more context about why you want to update each original model's state_dict?

kxhit commented 2 years ago

@zou3519 Hi, thanks for your quick reply.

In my case, I'm training many tiny networks and need to use the up-to-date weights of each network every few steps. So I need to assign the batched weights back to the original models frequently.