facebookresearch / higher

higher is a PyTorch library allowing users to obtain higher-order gradients over losses spanning training loops rather than individual training steps.
Apache License 2.0

Feature request: utility functions to allow stopping meta-gradient propagation #10

Open llucid-97 opened 5 years ago

llucid-97 commented 5 years ago

Hi there, it would be really useful if higher's API allowed blocking gradients at a point in the graph.

Use Cases:

Proposed API

The simplest and most intuitive way, in my opinion, is to make the differentiable optimizers work with PyTorch's existing .state_dict() and .load_state_dict(), so we can call them on PyTorch's normal modules and optimizers like:

nn_module_model.load_state_dict(functional_model.state_dict())
normal_optimizer.load_state_dict(differentiable_optimizer.state_dict())

To continue training with meta gradients, we can then re-create the higher versions of these. The functional models already do this, but the differentiable optimizers don't. I've hacked together a patch that works for my use case (2-step unrolling), but I doubt it would work in general.
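
To make the idea concrete, here is a minimal sketch of the round trip I have in mind, assuming a toy model and a plain (stateless) SGD inner optimizer, so that simply re-creating the differentiable optimizer each window is enough; with a stateful optimizer like Adam you would additionally need the optimizer state_dict round trip requested above. model, opt, the loop lengths, and the random data are all placeholders:

import torch
import torch.nn.functional as F
import higher

model = torch.nn.Linear(4, 1)                      # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=0.1)  # stateless inner optimizer

for outer_step in range(3):
    # copy_initial_weights=False lets meta-gradients reach model's own parameters
    with higher.innerloop_ctx(model, opt, copy_initial_weights=False) as (fmodel, diffopt):
        for inner_step in range(2):                # K-step unroll (K = 2 here)
            x, y = torch.randn(8, 4), torch.randn(8, 1)
            diffopt.step(F.mse_loss(fmodel(x), y))
        # the meta-loss backpropagates through the K unrolled steps only
        meta_loss = F.mse_loss(fmodel(torch.randn(8, 4)), torch.randn(8, 1))
        meta_loss.backward()
        # (a meta-optimizer step over model.parameters(), then zero_grad, would go here)

    # copy the adapted fast weights back into the plain module; the next
    # innerloop_ctx then starts from a fresh graph, so the old unroll can be freed
    model.load_state_dict(fmodel.state_dict())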

An extension for convenience

Lastly, a more convenient (but less intuitive) extension to this idea would be to have the old .detach() and .detach_() methods for torch tensors apply to the functional models. It didn't make sense for nn.Modules since they weren't differentiable, but it would be very convenient for functional models since they (their parameters) are now differentiable too:

# critic needs to be trained and backprop its meta-gradients
critic_loss = F.mse_loss(
    fmodel_critic(state, action),
    target_q_value
)

# actor also needs to be trained and backprop its meta-gradients, but the critic
# used to train it should NOT backprop the policy's loss into its own parameters
fmodel_critic_for_training_actor = fmodel_critic.detach()  # copy of fmodel_critic that doesn't alter the original's params
actor_loss = -fmodel_critic_for_training_actor(
    state,
    fmodel_actor(state)
)

# backpropagation and updates now ignore the detached copy
diffopt_critic.step(critic_loss)
diffopt_actor.step(actor_loss)

Of course, these are just for convenience and efficiency, since the state_dict methods could functionally do the same things, albeit more slowly.
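
In the meantime, the detached copy could be approximated by hand. Below is a rough sketch under the assumption that higher.monkeypatch, update_params, and fast_params behave as used elsewhere in this thread; detached_copy is a made-up helper name (not part of the library), and critic / fmodel_critic are the placeholders from the example above:

import higher

def detached_copy(module, fmodel):
    """Build a second patched copy of `module` whose weights are detached
    views of `fmodel`'s current fast weights: gradients still flow through
    its forward pass to the inputs, but never into these parameters."""
    clone = higher.monkeypatch(module)
    clone.update_params([p.detach() for p in fmodel.fast_params])
    return clone

fmodel_critic_for_training_actor = detached_copy(critic, fmodel_critic)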

egrefen commented 5 years ago

Thanks, this is super. Just FYI, I am on leave all of November, and most of December. If this is urgent, please flag this here. I will try and find time when I return early December to implement an appropriate and robust state_dict method for patched modules and differentiable optimizers which does the right thing.

In the meantime, if you fancy having a stab at doing this yourself, we welcome contributions and it would be very helpful indeed.

AntoineHX commented 5 years ago

Hi! I could also really use this type of utility function! For now, the only working way I found to stop the meta-gradient propagation (and save the weights of the last K states) for an arbitrary K is to perform a complete copy of the model / optimizer states and gradients before re-instantiating fmodel and diffoptim. I also tried to directly alter the fast-weight memory of the fmodel to remove states from memory, but that didn't stop the memory from growing with each iteration until the GPU was saturated...

llucid-97 commented 5 years ago

@AntoineHX I saw no speedup from this detach version over just copying (in my experiments with 1-step meta-gradients it was actually slower), and I'm still not sure why. I'll investigate this again in about a month to figure out why, but for now I'm just getting by with the explicit-copy version.

The "memory leak" is caused by the autograd graph for old iterations being accidentally held onto by these persistent tensors:

So you need to detach both of them after you've finished the loops on optimization and meta-optimization.

If you'd like to try it, I made this .detach_() method for the DifferentiableOptimizer base class:

    def detach_(self):
        """Removes all params from their compute graph in place."""
        # detach the tensors held in the param groups
        # (_torch is higher's internal alias for the torch module)
        for group in self.param_groups:
            for k, v in group.items():
                if isinstance(v, _torch.Tensor):
                    v.detach_().requires_grad_()

        # detach the optimizer state (e.g. momentum / Adam moment buffers)
        for state_dict in self.state:
            for k, v_dict in state_dict.items():
                if isinstance(k, _torch.Tensor):
                    k.detach_().requires_grad_()
                for k2, v2 in v_dict.items():
                    if isinstance(v2, _torch.Tensor):
                        v2.detach_().requires_grad_()
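
As a rough usage sketch (batches, compute_loss, K, and meta_loss are placeholders): run K differentiable inner steps, consume the unrolled graph for the meta-objective, then cut the graph before the next window so old iterations can be freed.

for inner_step, (x, y) in enumerate(batches):
    diffopt.step(compute_loss(fmodel(x), y))
    if (inner_step + 1) % K == 0:
        meta_loss(fmodel).backward()  # backprop through this window's unroll first
        diffopt.detach_()             # then drop the graph held by the optimizer state
        # (the fmodel's fast weights need the same treatment)
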
AntoineHX commented 5 years ago

@ihexx Thanks for the quick reply! I tried to use your function, but for me it didn't work as intended. After using .detach_(), at the end of the next inner iteration the diffopt.step(loss) call raises a

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Maybe it's because I had to change your function from v2.detach_().requires_grad_() to v2 = v2.detach().requires_grad_() so that it isn't performed in place (views cannot be detached in place). I don't really have the time right now to dig any deeper, so I'll take your word that it doesn't get any faster than the copy version!

egrefen commented 4 years ago

@AntoineHX @ihexx We will look at this issue next week. In the meantime, in #15 we fixed some memory leak issues from #6. Could you please let us know if you are still seeing memory leaks after pulling a fresh version of the repo (and updating higher in your env if you pip installed it)?

egrefen commented 4 years ago

Hello. I've returned from leave and allocated some time to look into this issue over the next two weeks. Hopefully we'll make some progress and report back, or come back to you with questions.

egrefen commented 4 years ago

@ihexx Note the patched modules already support this, i.e. you can do something like

module = ...
fmodule = higher.monkeypatch(module)
with torch.no_grad():
    for p in module.parameters():
        p.add_(1)  # modify original params so that fmodule params are different
module.load_state_dict(fmodule.state_dict())  # restore params from fmodule

Does this fit your needs on the module side of things?

I'll look at the optimizers later this week or early next week as there are other bugs that have higher priority.

llucid-97 commented 4 years ago

Yeah, that looks great and it's exactly what I had in mind, thank you. I'll close this issue for now since higher now implements the core features I wanted, and the detach_() idea was a dud in retrospect (no matter how I sliced the graph, it still needed to copy the parameters somewhere, so it didn't turn out any more efficient, and the slicing made the code much messier than just using this explicit-copy style).

I'll try modifying some of my old code to use this in a couple of weeks, and re-open if I run into issues using it.

Thanks for the support :)

egrefen commented 4 years ago

@ihexx I'd like to check that it works in the general sense before closing the issue. I also need to check that this works for optimizers, which I think may require more work, so I'll keep this issue open...

AntoineHX commented 4 years ago

> @ihexx Note the patched modules already support this, i.e. you can do something like [...] module.load_state_dict(fmodule.state_dict()) [...] Does this fit your needs on the module side of things?

Personally, I get the same error that I got when first trying the detach_() method that @ihexx proposed:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Also, I'm not sure I follow why this approach should be better than detaching from the graph. Aren't we making unnecessary copies just to detach? We copy from fmodel to model, and from model back to fmodel (if we want to continue higher-order gradient tracking).

For now, I found that using the detach_() method for DifferentiableOptimizer, as well as performing some kind of detach for the fmodel, seemed to work best:

tmp = fmodel.fast_params
fmodel._fast_params = []          # clear the stored fast-weight history
fmodel.update_params(tmp)         # re-register only the latest fast weights
for p in fmodel.fast_params:
    p.detach_().requires_grad_()  # cut them from the old graph, keep them trainable

Even if it seems to me that this still performs unnecessary copies of the parameters...

thomas0809 commented 4 years ago

> Does this fit your needs on the module side of things? I'll look at the optimizers later this week or early next week as there are other bugs that have higher priority.

Hi,

I am wondering whether there are any updates on state_dict() for the optimizers?