llucid-97 opened this issue 5 years ago
Thanks, this is super. Just FYI, I am on leave all of November and most of December. If this is urgent, please flag it here. I will try to find time when I return in early December to implement an appropriate and robust state_dict method for patched modules and differentiable optimizers which does the right thing.
In the meantime, if you fancy having a stab at doing this yourself, we welcome contributions and it would be very helpful indeed.
Hi! I could also really use this type of utility function! For now, the only working way I have found to stop the meta-gradient propagation (and keep the weights of the last K states) with an arbitrary K is to perform a complete copy of the model / optimizer states and gradients before re-instantiating fmodel and diffoptim. I also tried to directly alter the fast-weight memory of the fmodel to remove old states, but it didn't stop memory from growing with every iteration until the GPU was saturated...
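As an illustration (not the author's actual code), here is a minimal sketch of that explicit-copy workaround, assuming the standard higher.innerloop_ctx setup; the tiny model, data, K and loop lengths below are placeholders:

import torch
import torch.nn as nn
import higher

model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
K = 3  # unroll length of each meta-gradient window

for outer_step in range(5):
    with higher.innerloop_ctx(model, opt, copy_initial_weights=False) as (fmodel, diffopt):
        for _ in range(K):  # only K differentiable inner steps per window
            inner_loss = ((fmodel(x) - y) ** 2).mean()
            diffopt.step(inner_loss)
        meta_loss = ((fmodel(x) - y) ** 2).mean()
        meta_loss.backward()  # meta-gradients for this window land on model.parameters()
        # (a meta-update using those gradients would normally go here)
    # Explicit copy: restart the next window from the adapted weights with no graph
    # history kept. Copying the adapted *optimizer* state back into opt is the part
    # this issue asks higher to support.
    model.load_state_dict(fmodel.state_dict())

Because each window starts from a fresh innerloop_ctx, no gradients can flow back past the previous K steps.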
@AntoineHX I saw no speedup from this detach version over just copying (in my experiments with 1-step meta-gradients it was actually slower), and I'm still not sure why. I'll investigate this again in about a month, but for now I'm just getting by with the explicit-copy version.
The "memory leak" is caused by the autograd graph for old iterations being accidentally held onto by these persistent tensors:
So you need to detach both of them after you've finished the loops on optimization and meta-optimization.
If you'd like to try it, I made this .detach_() method for the DifferentiableOptimizer base class:
def detach_(self):
    """Removes all params from their compute graph in place."""
    # detach param groups
    for group in self.param_groups:
        for k, v in group.items():
            if isinstance(v, _torch.Tensor):
                v.detach_().requires_grad_()
    # detach state
    for state_dict in self.state:
        for k, v_dict in state_dict.items():
            if isinstance(k, _torch.Tensor):
                k.detach_().requires_grad_()
            for k2, v2 in v_dict.items():
                if isinstance(v2, _torch.Tensor):
                    v2.detach_().requires_grad_()
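A hypothetical usage sketch, assuming the method above has been added to higher.optim.DifferentiableOptimizer (it is not part of higher's API); model, opt, loss_fn, the data tensors, K and total_steps are placeholders:

fmodel = higher.monkeypatch(model)
diffopt = higher.get_diff_optim(opt, model.parameters(), fmodel=fmodel)

for step in range(total_steps):
    diffopt.step(loss_fn(fmodel(x), y))
    if (step + 1) % K == 0:  # end of a K-step unroll window
        loss_fn(fmodel(x_val), y_val).backward()  # meta-gradients for this window
        diffopt.detach_()  # cut the optimizer state loose from the old graph
        # the patched module's fast weights need to be detached as well,
        # as discussed later in this thread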
@ihexx Thanks for the quick reply!
I tried to use your function, but for me it didn't work as intended. After calling .detach_(), at the end of the next inner iteration the diffopt.step(loss) call raises:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
Maybe it's because I had to change v2.detach_().requires_grad_() to v2 = v2.detach().requires_grad_() in your function, so that it isn't performed in place (views cannot be detached in place).
I don't really have the time right now to dig any deeper, so I'll take your word that it doesn't get any faster than the copy version!
@AntoineHX @ihexx We will look at this issue next week. In the meantime, in #15 we fixed some memory leak issues from #6. Could you please let us know if you are still seeing memory leaks after pulling a fresh version of the repo (and updating higher in your env if you pip-installed it)?
Hello. I've returned from leave and allocated some time to look into this issue over the next two weeks. Hopefully we'll make some progress and report back, or come back to you with questions.
@ihexx Note the patched modules already support this, i.e. you can do something like
module = ...
fmodule = higher.monkeypatch(module)
with torch.no_grad():
    for p in module.parameters():
        p.add_(1)  # modify original params so that fmodule params are different
module.load_state_dict(fmodule.state_dict())  # restore params from fmodule
Does this fit your needs on the module side of things?
I'll look at the optimizers later this week or early next week as there are other bugs that have higher priority.
Yeah, that looks great and it's exactly what I had in mind, thank you.
I'll close this issue for now since higher now implements the core features I wanted, and the detach_() idea was a dud in retrospect (no matter how I sliced the graph, it still needed to copy the parameters somewhere, so it didn't turn out any more efficient, and the slicing made the code much messier than just using the explicit-copy style).
I'll try modifying some of my old code to use this in a couple of weeks, and re-open if I run into issues using it.
Thanks for the support :)
@ihexx I'd like to check that it works in the general sense before closing the issue. I also need to check that this works for optimizers, which I think may require more work, so I'll keep this issue open...
Personally, with the load_state_dict approach above I get the same error I got when first trying the detach_() method that @ihexx proposed:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
Also, I'm not sure I follow why this approach should be better than detaching from the graph. Aren't we making unnecessary copies just to detach? We copy from fmodel to model, and then from model back to fmodel (if we want to continue higher's gradient tracking).
For now, I found that using the detach_() method for DifferentiableOptimizer, as well as performing some kind of detach on the fmodel, seemed to work best:
tmp = fmodel.fast_params              # keep only the latest set of fast weights
fmodel._fast_params = []              # drop the stored parameter history
fmodel.update_params(tmp)
for p in fmodel.fast_params:
    p.detach_().requires_grad_()      # cut them from the old graph, keep them trainable
Even so, it seems to me that this still performs unnecessary copies of the parameters...
Hi,
I am wondering whether there are any updates on state_dict() for the optimizers?
Hi there, it would be really useful if higher's API allowed blocking gradients at a point in the graph.
Use Cases:
- When working with Actor-Critic models, I'd want to allow gradients to flow through the critic when training it, but block them when using the critic to train the actor. It would be useful if I could make a detached copy of the critic for the second part.
- When working with meta-gradients that don't span the whole training loop (e.g. only using meta-gradients over K unrolled steps), it would be useful if we could detach the models so gradients don't flow past K.
Proposed API
The simplest and most intuitive way, in my opinion, is to make the differentiable optimizers work with PyTorch's existing .state_dict() and .load_state_dict(), so we can call them just as on PyTorch's normal modules and optimizers (see the sketch below). To continue training with meta-gradients, we can then re-create the higher versions of these. The functional models already do this, but the differentiable optimizers don't. I've hacked together a patch that works for my use case (2-step unrolling), but I doubt it would work in general.
An extension for convenience
Lastly, a more convenient (but less intuitive) extension to this idea would be to have the existing .detach() and .detach_() methods for torch tensors also apply to the functional models. That didn't make sense for nn.Modules, since they weren't differentiable, but it would be very convenient for functional models, since they (their parameters) now are:
- .detach() should create a fresh copy of the whole functional model that is cut from the computation graph, which is really useful for the actor-critic case.
- .detach_() should detach all param tensors in place. This would also be really useful for, and more efficient in, the K-step unroll use case, since we wouldn't need to deep-copy back to a PyTorch nn.Module and then copy again to a functional model to continue; with the .detach_() version we could just do it in place (sketched below).
Of course, these are just for convenience and efficiency, since the state_dict methods could do the same things functionally, albeit slower.
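For comparison, a hypothetical sketch of the in-place convenience being proposed (neither fmodel.detach_() nor diffopt.detach_() exists in higher; loss_fn, the data tensors, K and total_steps are placeholders):

for step in range(total_steps):
    diffopt.step(loss_fn(fmodel(x), y))
    if (step + 1) % K == 0:  # end of a K-step unroll window
        loss_fn(fmodel(x_val), y_val).backward()  # meta-gradients for this window
        fmodel.detach_()   # proposed: detach all fast params in place
        diffopt.detach_()  # proposed: detach the optimizer state in place
        # no copy back to an nn.Module and no re-patching needed; training just continues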