patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0
1.37k stars, 124 forks

Making sure that the backward pass uses Algorithm 2 from the paper 'Efficient and Accurate Gradients for Neural SDEs' #300

Open mpereira30 opened 1 year ago

mpereira30 commented 1 year ago

Looking closely at the description of the arguments of diffeqsolve, I realized that even if I choose diffrax.ReversibleHeun as the solver, the backward pass by default uses diffrax.RecursiveCheckpointAdjoint, which is also known as discretize-then-optimize. I specifically want the backward pass to use Algorithm 2 from the paper 'Efficient and Accurate Gradients for Neural SDEs', which leverages the algebraically reversible nature of the reversible Heun solver. It is not clear to me which of these adjoints (https://docs.kidger.site/diffrax/api/adjoints/) I should choose, because based on their descriptions, none of them seem to use the algebraic reversibility of the solver.

patrick-kidger commented 1 year ago

Hey there! This is a great observation. Indeed, Diffrax does not currently implement algebraically reversible backpropagation.

For now, I believe torchsde is the only library to implement this. (Although the library is not entirely maintained any more, it should mostly work, I think.)

Alternatively, if you're feeling ambitious, I'd be happy to shepherd a pull request adding this to Diffrax. I imagine it should be doable by (consider these notes for my future self as much as anyone else):

  1. creating a new AbstractReversibleSolver class and having ReversibleHeun inherit from it, with every reversible solver providing a corresponding adjoint solver;
  2. creating a new ReversibleAdjoint that checks you're passing in a reversible solver;
  3. on the forward pass, record the location of every timestep (by adding an extra SubSaveAt(steps=True, fn=lambda t, y, args: None)?);
  4. calling diffeqsolve on the backward pass, with stepsize_controller=StepTo(ts=<ts saved on the forward pass>), and solver=<adjoint solver>.
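To make the core idea behind these steps concrete, here is a minimal pure-JAX sketch. It is not Diffrax code, and it covers only the ODE case (no Brownian term), so it is a simplified version of the paper's Algorithm 2 under that assumption. The forward pass keeps only the final state and the step times `ts`. The backward pass walks `ts` in reverse, reconstructs each earlier state (exactly in exact arithmetic, up to round-off in floating point) using the algebraic inverse of the reversible Heun step, and backpropagates through that single step with `jax.vjp`. The names `vector_field`, `step`, `inverse_step`, `forward` and `reversible_grad`, and the tanh vector field, are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp


def vector_field(t, y, theta):
    # Hypothetical autonomous vector field dy/dt = tanh(theta @ y); t is unused.
    return jnp.tanh(theta @ y)


def step(y, z, t, dt, theta):
    """One reversible Heun step (ODE case): (y_n, z_n) -> (y_{n+1}, z_{n+1})."""
    f0 = vector_field(t, z, theta)
    z1 = 2.0 * y - z + f0 * dt
    f1 = vector_field(t + dt, z1, theta)
    y1 = y + 0.5 * (f0 + f1) * dt
    return y1, z1


def inverse_step(y1, z1, t, dt, theta):
    """Algebraic inverse of `step`: recover (y_n, z_n) from (y_{n+1}, z_{n+1})."""
    f1 = vector_field(t + dt, z1, theta)
    z0 = 2.0 * y1 - z1 - f1 * dt
    f0 = vector_field(t, z0, theta)
    y0 = y1 - 0.5 * (f0 + f1) * dt
    return y0, z0


def forward(y0, theta, ts):
    """Forward solve over the recorded times `ts`, keeping only the final state."""
    def body(carry, t_pair):
        t, t_next = t_pair
        return step(*carry, t, t_next - t, theta), None

    (yN, zN), _ = jax.lax.scan(body, (y0, y0), (ts[:-1], ts[1:]))
    return yN, zN


def reversible_grad(y0, theta, ts, loss_fn):
    """Gradients of loss_fn(y_N) w.r.t. (y0, theta) without storing intermediate states."""
    yN, zN = forward(y0, theta, ts)
    a_y = jax.grad(loss_fn)(yN)   # dL/dy_N
    a_z = jnp.zeros_like(zN)      # the loss does not depend on z_N
    a_theta = jnp.zeros_like(theta)

    def body(carry, t_pair):
        y1, z1, a_y, a_z, a_theta = carry
        t, t_next = t_pair
        dt = t_next - t
        # 1. Reconstruct the previous state by inverting the step.
        y, z = inverse_step(y1, z1, t, dt, theta)
        # 2. Backpropagate the adjoints through that single step only.
        _, vjp_fn = jax.vjp(lambda y_, z_, th: step(y_, z_, t, dt, th), y, z, theta)
        g_y, g_z, g_theta = vjp_fn((a_y, a_z))
        return (y, z, g_y, g_z, a_theta + g_theta), None

    # reverse=True walks the recorded step times from last to first.
    init = (yN, zN, a_y, a_z, a_theta)
    (y_rec, _, a_y, a_z, a_theta), _ = jax.lax.scan(
        body, init, (ts[:-1], ts[1:]), reverse=True
    )
    # y_0 = z_0 = y0, so both adjoint contributions flow into dL/dy0.
    return a_y + a_z, a_theta, y_rec


ts = jnp.linspace(0.0, 1.0, 51)
y0 = jnp.array([1.0, -0.5, 0.25])
theta = jnp.array([[0.0, 1.0, 0.0], [-1.0, 0.0, 0.5], [0.2, -0.3, -0.1]])
loss_fn = lambda y: jnp.sum(y ** 2)

grad_y0, grad_theta, y0_rec = reversible_grad(y0, theta, ts, loss_fn)

# Reference: ordinary backpropagation through the forward computation.
ref_y0, ref_theta = jax.grad(
    lambda y, th: loss_fn(forward(y, th, ts)[0]), argnums=(0, 1)
)(y0, theta)
```

Because the reconstruction only has to undo floating-point round-off, the reversible gradient should agree with the reference gradient to within a small tolerance, while the backward pass does not need to store the intermediate states.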

AddisonHowe commented 3 months ago

This would be a really neat feature, I think. Has there been any progress on this? I'm fairly new to JAX but have worked a bit with Diffrax and Equinox, and I'm trying to decide whether I'm familiar enough with them to take a shot at this.

patrick-kidger commented 3 months ago

No progress on this! If you think it'd be interesting then go ahead and give it a try!