patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Question about BacksolveAdjoint through SemiImplicitEuler solver #201

Open Chenghao-Wu opened 1 year ago

Chenghao-Wu commented 1 year ago

I am testing the adjoint method for computing gradients through a solve that uses the SemiImplicitEuler solver. I get an error when computing the gradients with BacksolveAdjoint. Here is a minimal example that reproduces it. Any suggestions would be appreciated.

Thank you in advance!

```python
from diffrax import diffeqsolve, ODETerm, SemiImplicitEuler, SaveAt, BacksolveAdjoint
import jax.numpy as jnp
from jax import grad
from matplotlib import pyplot as plt


def drdt(t, v, args):
    return v


def dvdt(t, r, args):
    return -args[0] * (r - args[1])


terms = (ODETerm(drdt), ODETerm(dvdt))
solver = SemiImplicitEuler()
y0 = (jnp.array([1.0]), jnp.array([0.0]))
saveat = SaveAt(ts=jnp.arange(0, 30, 0.1))


def loss(y0):
    solution = diffeqsolve(
        terms, solver, t0=0, t1=30, dt0=0.0001, y0=y0, args=[1.0, 0.0],
        saveat=saveat, max_steps=10000000, adjoint=BacksolveAdjoint(),
    )
    return jnp.sum(solution.ys[0])


grads = grad(loss)(y0)
print(grads)
```

Here is the error message:

```
Traceback (most recent call last):
  File "test_harmonic.py", line 23, in <module>
    grads = grad(loss)(y0)
  File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/grad.py", line 482, in fn_bwd_wrapped
    out = fn_bwd(residuals, grad_diff_array_out, vjp_arg, *args, **kwargs)
  File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/adjoint.py", line 394, in _loop_backsolve_bwd
    state, _ = _scan_fun(state, val0, first=True)
  File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/adjoint.py", line 332, in _scan_fun
    _sol = diffeqsolve(
  File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/jit.py", line 82, in __call__
    return __self._fun_wrapper(False, args, kwargs)
  File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/jit.py", line 78, in _fun_wrapper
    dynamic_out, static_out = self._cached(dynamic, static)
  File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/jit.py", line 30, in fun_wrapped
    out = fun(*args, **kwargs)
  File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/integrate.py", line 858, in diffeqsolve
    final_state, aux_stats = adjoint.loop(
  File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/adjoint.py", line 499, in loop
    final_state, aux_stats = _loop_backsolve(
  File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/grad.py", line 509, in __call__
    out = self.fn_wrapped(
  File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/grad.py", line 443, in fn_wrapped
    out = self.fn(vjp_arg, *args, **kwargs)
  File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/adjoint.py", line 250, in _loop_backsolve
    return self._loop_fn(
  File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/integrate.py", line 497, in loop
    final_state = bounded_while_loop(
  File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/misc/bounded_while_loop.py", line 125, in bounded_while_loop
    return lax.while_loop(cond_fun, _body_fun, init_val)
  File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/misc/bounded_while_loop.py", line 118, in _body_fun
    _new_val = body_fun(_val, inplace)
  File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/integrate.py", line 137, in body_fun
    (y, y_error, dense_info, solver_state, solver_result) = solver.step(
  File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/solver/semi_implicit_euler.py", line 42, in step
    y0_1, y0_2 = y0
ValueError: too many values to unpack (expected 2)
```

patrick-kidger commented 1 year ago

Right! So this is actually expected, although the error message could certainly be improved.

When you use BacksolveAdjoint, you're saying that you'd like to solve another ODE system backwards in time. This involves (a) solving the original ODE backwards-in-time, but also (b) solving the adjoint system (which propagates the gradients backwards in time).

In particular, this combined system no longer has the two-part partitioned structure that SemiImplicitEuler expects, which is why the `y0_1, y0_2 = y0` unpacking in its step fails. (And by default, BacksolveAdjoint uses the same solver on the backward pass as on the forward pass.)
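To make the structural mismatch concrete, here is a minimal pure-JAX sketch of the augmented system that a continuous adjoint integrates backwards in time. The names `vector_field` and `augmented_vector_field`, and the three-part layout `(y, a_y, a_args)`, are assumptions chosen for illustration; the layout Diffrax uses internally may differ. The point is only that the backward state is no longer a pair, so a solver that unpacks its state into two parts cannot step it.

```python
import jax
import jax.numpy as jnp


def vector_field(t, y, args):
    # Partitioned harmonic oscillator: y = (r, v).
    r, v = y
    k, r_eq = args
    return (v, -k * (r - r_eq))


def augmented_vector_field(t, aug, args):
    # Hypothetical layout: state, adjoint of the state, adjoint of the args.
    y, a_y, a_args = aug
    dy, vjp_fn = jax.vjp(lambda _y, _args: vector_field(t, _y, _args), y, args)
    # Adjoint dynamics: da_y/dt = -a_y^T df/dy and da_args/dt = -a_y^T df/dargs.
    ct_y, ct_args = vjp_fn(a_y)
    negate = lambda tree: jax.tree_util.tree_map(jnp.negative, tree)
    return (dy, negate(ct_y), negate(ct_args))


y = (jnp.array([1.0]), jnp.array([0.0]))
a_y = (jnp.array([1.0]), jnp.array([0.0]))
args = (jnp.array(1.0), jnp.array(0.0))
a_args = (jnp.array(0.0), jnp.array(0.0))

aug = (y, a_y, a_args)
d_aug = augmented_vector_field(0.0, aug, args)

# A partitioned solver step assumes exactly two components.
unpack_error = None
try:
    y_1, y_2 = aug
except ValueError as e:
    unpack_error = str(e)
print(unpack_error)
```

The original state `y` is still a pair here, but it is now nested inside a larger tuple alongside the adjoint variables, and that larger tuple is what the backward solve sees.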

The "correct" fix here would be to define some new solver that is able to handle the combined partitioned+adjoint system arising from backpropagating through a partitioned system. You could then pass BacksolveAdjoint(solver=MyNewSolver()) and things would work. This would actually be relatively straightforward if you know what you're doing.


For the sake of improving the error message, I've marked this as a bug. And to have this new solver be defined and actually used automatically, I've marked this as a feature.

In the meantime, I'd recommend using the default adjoint method instead, which is usually a better choice than BacksolveAdjoint anyway.
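For context, the default adjoint in Diffrax (`RecursiveCheckpointAdjoint`) differentiates through the numerical solve itself rather than solving a second ODE backwards. The sketch below imitates that idea in plain JAX, without Diffrax: it hand-writes a fixed-step semi-implicit Euler loop with `jax.lax.scan` and applies `jax.grad` to it. The function `semi_implicit_euler`, the step size, and the step count are assumptions for illustration, not Diffrax's implementation.

```python
import jax
import jax.numpy as jnp


def drdt(t, v, args):
    return v


def dvdt(t, r, args):
    k, r_eq = args
    return -k * (r - r_eq)


def semi_implicit_euler(y0, args, t0, dt, num_steps):
    # Hand-written scheme: update r using the old v, then v using the new r.
    def step(carry, t):
        r, v = carry
        r_new = r + dt * drdt(t, v, args)
        v_new = v + dt * dvdt(t, r_new, args)
        return (r_new, v_new), r_new

    ts = t0 + dt * jnp.arange(num_steps)
    _, rs = jax.lax.scan(step, y0, ts)
    return rs


def loss(y0):
    rs = semi_implicit_euler(y0, args=(1.0, 0.0), t0=0.0, dt=0.01, num_steps=100)
    return jnp.sum(rs)


y0 = (jnp.array([1.0]), jnp.array([0.0]))
# Reverse-mode autodiff through the discrete steps; the two-part state is preserved.
grads = jax.grad(loss)(y0)
print(grads)
```

Because the gradient is taken through the same discrete steps, the partitioned `(r, v)` structure never changes. In the example from the question, the corresponding change is to drop `adjoint=BacksolveAdjoint()` from the `diffeqsolve` call.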

Other miscellaneous thoughts: