facebookresearch / higher

higher is a PyTorch library that lets users obtain higher-order gradients over losses spanning training loops rather than individual training steps.
Apache License 2.0

Question about gradient checkpointing #22

Closed. JonMuehlst closed this issue 4 years ago.

JonMuehlst commented 4 years ago

Is it possible to use gradient checkpointing with higher? Thanks

egrefen commented 4 years ago

I need a bit more detail here. Can you explain what you mean, point me to example code, and tell me what you tried and what went wrong?

MarisaKirisame commented 4 years ago

Hi Jon, I am interested in gradient checkpointing. Can you give me your use case so I can investigate further?

egrefen commented 4 years ago

The pytorch docs for torch.utils.checkpoint state that "Checkpointing doesn’t work with torch.autograd.grad(), but only with torch.autograd.backward()."

As higher uses torch.autograd.grad() in differentiable optimizers, it seems that checkpointing using the default utilities is not supported at this time. If this is an issue for you, please raise it in the pytorch repo.
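For concreteness, here is a minimal sketch (mine, not from the thread) of the limitation quoted above. It assumes the reentrant `torch.utils.checkpoint` implementation, which was the historical default (recent PyTorch releases ask you to pass `use_reentrant` explicitly); the toy `Linear` module and tensor shapes are illustrative only.

```python
import torch
from torch.utils.checkpoint import checkpoint

lin = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

# Checkpointed forward pass: activations inside `lin` are dropped and
# recomputed during the backward pass to save memory.
y = checkpoint(lin, x)
loss = y.sum()
loss.backward()  # fine: checkpointing supports torch.autograd.backward()

# higher's differentiable optimizers call torch.autograd.grad() instead,
# which the reentrant checkpoint implementation rejects.
y = checkpoint(lin, x)
loss = y.sum()
try:
    grads = torch.autograd.grad(loss, tuple(lin.parameters()))
except RuntimeError as err:
    print("torch.autograd.grad() failed:", err)
```

(Newer PyTorch versions also provide a non-reentrant variant, `checkpoint(..., use_reentrant=False)`, which is documented to work with `torch.autograd.grad()`, but that postdates this discussion.)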

In the meantime, and in the absence of further information about the issue from @JonMuehlst, I will close this. Please re-open if you have more detail.

blake-camp commented 2 years ago

I was wondering if there has been any movement on this. Suppose I want to compute meta-gradients through a long learning trajectory (i.e. a large number of inner-loop updates). Gradient checkpointing of the meta-gradient computation would help reduce the memory cost of doing so, since the intermediate values needed to compute the meta-gradients would not have to be stored during the original forward pass of the inner loop: they could be recomputed as needed during the outer-loop backward pass. Here is an example of memory-efficient MAML using gradient checkpointing: https://github.com/dbaranchuk/memory-efficient-maml. I was wondering whether this would be possible with higher, because the higher interface is so simple.
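To make the memory issue concrete, here is a rough sketch (my own, not from the linked repo) of a typical higher inner loop; the model, data, and `num_inner_steps` are placeholders. Each of the `num_inner_steps` calls to `diffopt.step()` keeps its forward activations and the resulting patched parameters alive until `meta_loss.backward()` runs, so memory grows roughly linearly with the trajectory length, which is exactly what a recompute-instead-of-store checkpointing scheme would alleviate.

```python
import torch
import higher

model = torch.nn.Linear(10, 1)                              # placeholder inner model
meta_opt = torch.optim.Adam(model.parameters(), lr=1e-3)
inner_opt = torch.optim.SGD(model.parameters(), lr=1e-1)
num_inner_steps = 100                                       # a "long learning trajectory"

x = torch.randn(32, 10)                                     # placeholder task data
y = torch.randn(32, 1)

meta_opt.zero_grad()
with higher.innerloop_ctx(model, inner_opt) as (fmodel, diffopt):
    # Every step stores activations and a new set of (functional) parameters
    # so that meta-gradients can later flow through the whole trajectory.
    for _ in range(num_inner_steps):
        inner_loss = torch.nn.functional.mse_loss(fmodel(x), y)
        diffopt.step(inner_loss)

    # Backprop through all num_inner_steps updates; with checkpointing the
    # intermediate state could instead be recomputed segment by segment.
    meta_loss = torch.nn.functional.mse_loss(fmodel(x), y)
    meta_loss.backward()

meta_opt.step()
```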