JonMuehlst closed this issue 4 years ago.
I need a bit more detail here. Can you explain what you mean, point me to example code, and tell me what you tried and what went wrong?
Hi Jon, I am interested in gradient checkpointing. Can you describe your use case so I can investigate further?
The PyTorch docs for torch.utils.checkpoint state that "Checkpointing doesn't work with torch.autograd.grad(), but only with torch.autograd.backward()." As higher uses torch.autograd.grad() in its differentiable optimizers, it seems that checkpointing using the default utilities is not supported at this time. If this is an issue for you, please raise it in the pytorch repo.
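For concreteness, here is a minimal sketch of that incompatibility. The block function and shapes are illustrative, and exact behavior depends on your PyTorch version; recent releases add a non-reentrant mode (use_reentrant=False) that is documented to work with torch.autograd.grad().

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # Stand-in for a sub-network whose activations we would rather
    # recompute in the backward pass than keep in memory.
    return torch.tanh(x @ x)

x = torch.randn(8, 8, requires_grad=True)

# Supported path: torch.autograd.backward() through a checkpointed segment.
checkpoint(block, x).sum().backward()
print("backward() ok:", x.grad.shape)

# The path higher needs: torch.autograd.grad() through the same segment.
# With the (reentrant) checkpoint implementation of that era, this raises
# a RuntimeError.
try:
    y = checkpoint(block, x).sum()
    (g,) = torch.autograd.grad(y, x)
    print("autograd.grad() ok:", g.shape)
except RuntimeError as e:
    print("autograd.grad() failed:", e)
```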
In the meantime, and in the absence of further information about the issue from @JonMuehlst, I will close this. Please re-open if you have more detail.
I was wondering if there has been any movement on this. Suppose I want to compute meta-gradients through a long learning trajectory (i.e. a large number of inner-loop updates). Gradient checkpointing would help reduce the memory cost of doing so, since the intermediate values needed to compute the meta-gradients would not have to be stored during the original forward pass of the inner loop; they could instead be recomputed as needed during the outer-loop backward pass. Here is an example with MAML, memory-efficient MAML using gradient checkpointing: https://github.com/dbaranchuk/memory-efficient-maml. I was wondering whether this would be possible with higher, since the higher interface is so simple.
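For concreteness, here is a minimal sketch of the kind of long inner loop in question, written against higher's innerloop_ctx API; the toy model, data, and step count are illustrative and not taken from the linked repo.

```python
import torch
import torch.nn as nn
import higher

model = nn.Linear(4, 1)
meta_opt = torch.optim.Adam(model.parameters(), lr=1e-3)
inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(32, 4), torch.randn(32, 1)

meta_opt.zero_grad()
# copy_initial_weights=False so meta-gradients reach model's own parameters.
with higher.innerloop_ctx(model, inner_opt,
                          copy_initial_weights=False) as (fmodel, diffopt):
    for _ in range(100):  # long trajectory: 100 differentiable inner steps
        inner_loss = ((fmodel(x) - y) ** 2).mean()
        # diffopt.step() uses torch.autograd.grad() internally, which is
        # why torch.utils.checkpoint cannot wrap the inner computation.
        diffopt.step(inner_loss)
    # Every inner step's graph stays alive until this backward pass;
    # checkpointing would let those intermediates be recomputed instead.
    meta_loss = ((fmodel(x) - y) ** 2).mean()
    meta_loss.backward()
meta_opt.step()
```

Since diffopt.step() relies on torch.autograd.grad(), wrapping the inner computation in torch.utils.checkpoint runs into exactly the restriction quoted above.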
Is it possible to use gradient checkpointing with higher? Thanks