mpereira30 opened this issue 1 year ago
Hey there! This is a great observation. Indeed, Diffrax does not currently implement algebraically reversible backpropagation.
For now, I believe torchsde is the only library to implement this. (Although the library is not entirely maintained any more, it should mostly work, I think.)
Alternatively, if you're feeling ambitious, I'd be happy to shepherd a pull request adding this to Diffrax. I imagine it should be doable by (consider these notes for my future self as much as anyone else):

- Adding an `AbstractReversibleSolver` class, and having `ReversibleHeun` inherit from it. Have all reversible solvers provide a corresponding adjoint solver.
- Adding a `ReversibleAdjoint` that checks you're passing in a reversible solver; saves the time of every step on the forward pass (perhaps via `SubSaveAt(steps=True, fn=lambda t, y, args: None)`?); and then calls `diffeqsolve` on the backward pass, with `stepsize_controller=StepTo(ts=<ts saved on the forward pass>)` and `solver=<adjoint solver>`.

This would be a really neat feature, I think.
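(Editor's note: below is a rough, hypothetical sketch of the forward/backward plumbing described in these notes, using only APIs that exist in Diffrax today: `ReversibleHeun`, `SubSaveAt`, and `StepTo`. The proposed `AbstractReversibleSolver` and `ReversibleAdjoint` do not exist, so the backward pass here simply re-runs the same solver over the reversed step grid, standing in for the eventual adjoint solver.)

```python
import diffrax
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: -y)
solver = diffrax.ReversibleHeun()

# Forward pass: save the final state, plus the time of every accepted step.
# `fn=lambda t, y, args: None` stores no values for the per-step sub, only times.
saveat = diffrax.SaveAt(
    subs=(
        diffrax.SubSaveAt(t1=True),
        diffrax.SubSaveAt(steps=True, fn=lambda t, y, args: None),
    )
)
fwd = diffrax.diffeqsolve(
    term, solver, t0=0.0, t1=1.0, dt0=0.01, y0=jnp.array([1.0]), saveat=saveat
)
y1 = fwd.ys[0][-1]   # final state, from the first SubSaveAt
step_ts = fwd.ts[1]  # per-step times, from the second SubSaveAt (inf-padded)

# Backward pass: retrace exactly the same time grid, reversed. A real
# `ReversibleAdjoint` would pass `solver=<corresponding adjoint solver>` here.
step_ts = step_ts[step_ts < jnp.inf]  # drop the unused inf padding
ts_bwd = jnp.concatenate([jnp.array([0.0]), step_ts])[::-1]
bwd = diffrax.diffeqsolve(
    term, solver, t0=1.0, t1=0.0, dt0=None, y0=y1,
    stepsize_controller=diffrax.StepTo(ts=ts_bwd),
)
```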
Has there been any progress on this? I'm fairly new to JAX, but have worked a bit with diffrax and equinox, and am trying to decide whether I'm familiar enough to take a shot at this.
No progress on this! If you think it'd be interesting then go ahead and give it a try!
Looking closely at the description of the arguments of `diffeqsolve`, I realized that irrespective of choosing the solver to be `diffrax.ReversibleHeun`, the backward pass by default uses `diffrax.RecursiveCheckpointAdjoint`, which is a discretize-then-optimize method. I specifically want to use Algorithm 2 from the paper 'Efficient and Accurate Gradients for Neural SDEs' for the backward pass, which leverages the algebraically reversible nature of the reversible Heun solver. It is not clear to me which `adjoint` I should choose from those listed at https://docs.kidger.site/diffrax/api/adjoints/, because based on their descriptions, none of them seem to use the algebraic reversibility of the solver.
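(Editor's note: a minimal sketch of the behaviour described above, using only documented Diffrax APIs. It just makes the default adjoint explicit; it does not perform a reversible backward pass, which as of this thread would require a new `AbstractAdjoint` subclass.)

```python
import diffrax
import jax
import jax.numpy as jnp

def loss(y0):
    term = diffrax.ODETerm(lambda t, y, args: -y)
    sol = diffrax.diffeqsolve(
        term,
        diffrax.ReversibleHeun(),  # reversible solver on the forward pass...
        t0=0.0,
        t1=1.0,
        dt0=0.01,
        y0=y0,
        # ...but gradients are still computed by checkpointed
        # discretize-then-optimize. This is the default even when the solver
        # happens to be algebraically reversible.
        adjoint=diffrax.RecursiveCheckpointAdjoint(),
    )
    return jnp.sum(sol.ys[-1])

grads = jax.grad(loss)(jnp.array([1.0]))
```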