Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"
Try to implement checkpointing (inserting recomputation to trade off computation against memory use) and then automatic checkpointing, which is what PyTorch/JAX users now reportedly need and can't get.
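To make the trade-off concrete, here is a minimal sketch in plain Python. It is not horde-ad's API and not the scheme from the discussion linked below; it works numerically on a chain of scalar functions, and the names `grad_chain_full`, `grad_chain_checkpointed` and `segment` are hypothetical, chosen only for this illustration. The first function stores every intermediate input for the backward pass; the second stores only one input per segment and recomputes the rest when the backward pass reaches that segment.

```python
import math

def grad_chain_full(fs, dfs, x):
    """Plain reverse mode: keep the input of every step for the backward pass."""
    inputs = []
    for f in fs:
        inputs.append(x)
        x = f(x)
    g = 1.0
    for df, xi in zip(reversed(dfs), reversed(inputs)):
        g *= df(xi)
    # Returns the value, the derivative and the number of stored inputs.
    return x, g, len(inputs)

def grad_chain_checkpointed(fs, dfs, x, segment):
    """Checkpointed reverse mode (hypothetical helper, an assumption of this sketch):
    keep only the input of each segment and recompute inside the segment."""
    n = len(fs)
    checkpoints = {}
    for i in range(n):
        if i % segment == 0:
            checkpoints[i] = x  # the only values kept from the forward pass
        x = fs[i](x)
    g = 1.0
    peak = len(checkpoints)
    for start in reversed(range(0, n, segment)):
        end = min(start + segment, n)
        # Recompute the inputs of this segment from its checkpoint.
        xi = checkpoints[start]
        inputs = []
        for i in range(start, end):
            inputs.append(xi)
            xi = fs[i](xi)
        peak = max(peak, len(checkpoints) + len(inputs))
        # Backward pass over the segment, last step first.
        for i in reversed(range(start, end)):
            g *= dfs[i](inputs[i - start])
    # Returns the value, the derivative and the peak number of stored inputs.
    return x, g, peak

if __name__ == "__main__":
    fs = [math.sin] * 16
    dfs = [math.cos] * 16
    print(grad_chain_full(fs, dfs, 0.5))
    print(grad_chain_checkpointed(fs, dfs, 0.5, segment=4))
```

With 16 steps and a segment length of 4, the plain version stores 16 inputs while the checkpointed one holds at most 8 at a time (4 checkpoints plus 4 recomputed inputs), at the price of running each step of the forward pass a second time. Automatic checkpointing would amount to choosing the segment boundaries without user annotations.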
We have an old discussion that starts with @tomjaguarpaw sketching an extension of the POPL paper with checkpointing https://github.com/Mikolaj/mostly-harmless/discussions/20. At some point we also had two variants of (things related to) checkpointing implemented, following a peak of popular interest, but they bit-rotted before anybody found them interesting again and before any benchmarks were written for them, and they were removed when horde-ad got simplified.
I wonder whether, in the current mode of operation, where we do reverse differentiation symbolically instead of using the real inputs, the memory leak problems posed in the discussion are gone. More generally, I wonder how checkpointing in the current mode would differ from what Tom describes and whether PyTorch/JAX do checkpointing in both modes of operation.
I'd advise against implementing it again until the interest is demonstrated by tests and benchmarks written by the interested parties.