I agree, this would all be useful to have.
RE: Auxiliary solve output, see @jacobjinkelly's https://github.com/google/jax/pull/2574. Doing this right may require a more comprehensive solution adding side-effects to core JAX.
RE: Differentiating through the solver. This would also be great to have, along with other adjoint calculation methods (e.g., as described in Appendix B of https://arxiv.org/pdf/2001.04385.pdf and implemented in DifferentialEquations.jl).
Differentiating through the solver could be a little tricky to do in JAX, because we need fixed memory usage in order to compile something with XLA (i.e., `jax.jit`). This could be achieved if we have some sort of `maxiter` parameter and used `scan` instead of `while_loop`, but a naive implementation would be extremely memory intensive. To make this practical, I think we would need some form of gradient checkpointing (also useful for other adjoint methods).
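For concreteness, here is a rough sketch (my own, untested, not part of `jax.experimental.ode`) of the scan-instead-of-while_loop idea, using a fixed step count for simplicity; an adaptive scheme with a `maxiter` bound would follow the same pattern and have the same memory profile:

```python
# Rough sketch: a fixed-step RK4 solve written with lax.scan so that jax.grad
# works through it under jit. The price is that scan saves every intermediate
# step for the backward pass, which is the memory blow-up mentioned above.
import jax
import jax.numpy as jnp
from jax import lax

def rk4_scan_solve(func, y0, t0, t1, num_steps=100):
    dt = (t1 - t0) / num_steps

    def step(carry, _):
        y, t = carry
        k1 = func(y, t)
        k2 = func(y + 0.5 * dt * k1, t + 0.5 * dt)
        k3 = func(y + 0.5 * dt * k2, t + 0.5 * dt)
        k4 = func(y + dt * k3, t + dt)
        y_next = y + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
        return (y_next, t + dt), y_next

    (y_final, _), _ = lax.scan(step, (y0, t0), None, length=num_steps)
    return y_final

# Gradients w.r.t. y0 flow straight through the scan-unrolled steps.
f = lambda y, t: -y
print(jax.grad(lambda y0: rk4_scan_solve(f, y0, 0.0, 1.0))(1.0))  # ~exp(-1)
```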
> RE: Auxiliary solve output, see @jacobjinkelly's #2574. Doing this right may require a more comprehensive solution adding side-effects to core JAX.
Wasn't aware of #2574. Def happy to see progress on that front! Although it seems inevitable that more and more auxiliary info will be desirable in the future. If I'm understanding #2574 correctly this would entail breaking API changes any time a new diagnostic output is added.
> RE: Differentiating through the solver. This would also be great to have, along with other adjoint calculation methods (e.g., as described in Appendix B of https://arxiv.org/pdf/2001.04385.pdf and implemented in DifferentialEquations.jl).
>
> Differentiating through the solver could be a little tricky to do in JAX, because we need fixed memory usage in order to compile something with XLA (i.e., `jax.jit`). This could be achieved if we have some sort of `maxiter` parameter and used `scan` instead of `while_loop`, but a naive implementation would be extremely memory intensive. To make this practical, I think we would need some form of gradient checkpointing (also useful for other adjoint methods).
Out of curiosity, why does XLA require bounded loops in order to do reverse-mode? Other IRs don't seem to have this limitation, eg. Relay which supports reverse-mode gradients through arbitrary loops. I believe their approach is based on https://arxiv.org/pdf/1803.10228.pdf. Julia also doesn't seem to have any issue with this.
> If I'm understanding #2574 correctly this would entail breaking API changes any time a new diagnostic output is added.
I think the right way to do this is to switch to returning an object with an extensible list of fields in addition to the ODE solution, like SciPy's `odeint` or (newer) `solve_ivp`. This would allow for adding more auxiliary fields without breaking code.
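Something along these lines, perhaps (hypothetical names, not an existing JAX API):

```python
# Hypothetical sketch of an extensible result container, modeled loosely on
# scipy's solve_ivp output. New diagnostic fields could be appended later
# without breaking callers that only read `ys`.
from typing import NamedTuple
import jax.numpy as jnp

class ODESolution(NamedTuple):
    ts: jnp.ndarray            # requested output times
    ys: jnp.ndarray            # solution values at those times
    num_steps: jnp.ndarray     # example diagnostic: accepted steps taken
    num_rejected: jnp.ndarray  # example diagnostic: rejected steps

# Callers unpack only what they need, e.g.:
# sol = odeint_with_info(f, y0, ts)   # odeint_with_info is a made-up name
# plt.plot(sol.ts, sol.ys)
```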
> Out of curiosity, why does XLA require bounded loops in order to do reverse-mode?
XLA has a general requirement that all memory allocation needs to be statically known based on shapes. There's no dynamic allocation of arrays based on computation results. I don't know the exact reason for this requirement, but I imagine that it makes the compiler's job much easier.
> XLA has a general requirement that all memory allocation needs to be statically known based on shapes. There's no dynamic allocation of arrays based on computation results. I don't know the exact reason for this requirement, but I imagine that it makes the compiler's job much easier.
Mmm, I can see how that would make the compiler much simpler, but it seems quite limiting now. Would it be possible to implement this without `jit` to circumvent the peculiarities of XLA? Doing so would be slower to be sure, but it would be better than nothing.
> Mmm, I can see how that would make the compiler much simpler
It's not just about simplicity; being able to statically analyze the shapes and memory requirements enables a ton of optimizations (fusion, layout, remat, etc).
> Would it be possible to implement this without `jit` to circumvent the peculiarities of XLA? Doing so would be slower to be sure, but it would be better than nothing.
Indeed, just write it with a regular Python while loop and don't jit it! I'm not sure this is something we'd want to maintain in the jax core repo, but it should be easy enough to stand up as a baseline.
> Indeed, just write it with a regular Python while loop and don't jit it! I'm not sure this is something we'd want to maintain in the jax core repo, but it should be easy enough to stand up as a baseline.
Ok, will do! Could I just reuse the `jax.experimental.ode` implementation of RK to do this? Is there any reason `jax.disable_jit()` wouldn't do the trick?
> It's not just about simplicity; being able to statically analyze the shapes and memory requirements enables a ton of optimizations (fusion, layout, remat, etc).
OTOH, what prevents XLA from intelligently inferring loop bounds and then applying optimizations where possible? This seems to be fairly standard practice in compilers.
Yeah, `disable_jit` might work! Though perhaps not with vmap because of the other issue you identified.
Ok, gotcha!
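For reference, the un-jitted baseline could be as simple as this (untested sketch using the existing `jax.experimental.ode.odeint`):

```python
# Untested sketch: run the existing odeint eagerly under disable_jit so that
# nothing gets staged out to XLA. Slow, but it sidesteps the fixed-memory
# requirement discussed above (vmap may still be an issue, as noted).
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def dynamics(y, t):
    return -y

with jax.disable_jit():
    ys = odeint(dynamics, jnp.ones(3), jnp.linspace(0.0, 1.0, 10))
```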
Re: loop bounds, sure that's possible in some cases (though it's more of a JAX issue than an XLA issue because XLA doesn't do autodiff). Actually that was my intention with this code: if we can specialize on the trip count then we can do reverse-mode autodiff. (It failed some internal test that I can't remember, but I think it should work...)
But the point of `while_loop` is that you might not be able to predict the trip count. Something like an ODE integrator, or anything you run to numerical convergence, is a good example: a priori you don't have a bound on the number of iterations. (There are some tricks you can play, like using recursive checkpointing and assuming that loop trip counts are always bounded by 2**32 or something, but that's just in the weeds.)
Ok after chatting with @MarisaKirisame it sounds like the key difference here is that Relay supports ADTs and allocation in the IR. So in their case the AD of a while loop is just another while loop with tracing that runs in the IR. This is made fast thanks to all kinds of compiler tricks within the IR. My understanding is that when loop bounds can be inferred they switch on using optimized, shape-aware compilation.
If I'm understanding correctly, this means that even an omnipotent JAX would not be able to do AD through while loops when compiling down to XLA. If that's true, it seems as though the way forward to fix AD for loops is either to bound the trip count (eg the maxiter-plus-scan approach above) or to target an IR that, like Relay, supports allocation.
RE: Alternate solver choices. Do you mean implicit solvers like the Livermore Solver (LSODE; https://computing.llnl.gov/casc/nsde/pubs/u113855.pdf), which are the de facto choice in scipy and can handle stiff equations? I started a project to write LSODE for TensorFlow and, let me say, it's tricky due to the many decisions made (page 58 of the link). There are a ton of heuristics that make things converge nicely and efficiently. There is, however, the option of stripping most of that logic out, down to the fundamentals of an implicit ODE solver (which is also described fully on pages 1-38 in that document).
@Joshuaalbert Yeah, I say the more options the better! I guess I view these sorts of things as good stress tests for tools like JAX and XLA. That flowchart looks pretty nasty, so it'd definitely be a challenge but I believe it should be do-able. Another option would be to support third-party solvers through FFI the same way scipy does. IIRC JAX has some kind of internal-ish way to do that sort of thing somewhere already.
An implicit solver for stiff ODEs would definitely be a welcome addition. We have a BDF method in TF-probability that could be a good start.
It would also be nice to expose an interface for reusing the ODE gradient definition(s), allowing users to bring their own solvers without needing to write new gradient rules. This would be similar to what we did with `lax.custom_root` and `lax.custom_linear_solve`.
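For anyone unfamiliar with the pattern, here's a minimal (untested) example of how `lax.custom_root` lets a user bring their own solver while JAX supplies the gradient; an ODE analogue would pair a user-provided solver with the existing adjoint rule in the same way:

```python
import jax
import jax.numpy as jnp
from jax import lax

def cube_root(a):
    # Find x with x**3 = a using a user-supplied Newton solver, but let JAX
    # derive gradients w.r.t. `a` implicitly rather than through the loop.
    f = lambda x: x ** 3 - a

    def solve(f, x0):
        # "Bring your own solver": plain Newton iterations, not differentiated.
        x = x0
        for _ in range(20):
            x = x - f(x) / (3.0 * x ** 2)
        return x

    def tangent_solve(g, y):
        # Solve the scalar linearized system g(x) = y.
        return y / g(1.0)

    return lax.custom_root(f, 1.0, solve, tangent_solve)

print(jax.grad(cube_root)(8.0))  # d(a**(1/3))/da at a=8 is 1/12 ≈ 0.0833
```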
I'd like to add my voice to those who would like to see differentiating through the solver and other adjoint calculation methods. That would be great.
Also, it appears that the current version of experimental.odeint cannot handle 64-bit data. The error message below stems from my call to odeint. Is it possible to alter odeint to handle 64-bit data?
tmp = odeint(MassSpring, latents, tp_train, rtol=rtol_make_data, atol=rtol_make_data)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 158, in odeint
return _odeint_wrapper(func, rtol, atol, mxstep, y0, t, *args)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/api.py", line 150, in f_jitted
name=flat_fun.__name__)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/core.py", line 951, in call_bind
outs = primitive.impl(f, *args, **params)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/interpreters/xla.py", line 463, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/linear_util.py", line 221, in memoized_fun
ans = call(fun, *args)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/interpreters/xla.py", line 480, in _xla_callable
fun, pvals, instantiate=False, stage_out=True, bottom=True)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 374, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/linear_util.py", line 150, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 164, in _odeint_wrapper
out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/custom_derivatives.py", line 455, in __call__
*args_flat, out_trees=out_trees)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/custom_derivatives.py", line 502, in _custom_vjp_call_bind
out_trees=out_trees)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 245, in process_custom_vjp_call
return fun.call_wrapped(*tracers)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/linear_util.py", line 150, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 199, in _odeint
_, ys = lax.scan(scan_fun, init_carry, ts[1:])
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 848, in scan
jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 62, in _initial_style_jaxpr
wrapped_fun, in_pvals, instantiate=True, stage_out=False)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 374, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/linear_util.py", line 150, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 189, in scan_fun
_, *carry = lax.while_loop(cond_fun, body_fun, [0] + carry)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 228, in while_loop
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 62, in _initial_style_jaxpr
wrapped_fun, in_pvals, instantiate=True, stage_out=False)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 374, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/linear_util.py", line 150, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 179, in body_fun
next_y, next_f, next_y_error, k = runge_kutta_step(func_, y, f, t, dt)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 116, in runge_kutta_step
k = lax.fori_loop(1, 7, body_fun, k)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 171, in fori_loop
(lower, upper, init_val))
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 228, in while_loop
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 62, in _initial_style_jaxpr
wrapped_fun, in_pvals, instantiate=True, stage_out=False)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 374, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/linear_util.py", line 150, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 100, in while_body_fun
return lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 112, in body_fun
ft = func(yi, ti)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 169, in <lambda>
func_ = lambda y, t: func(y, t, *args)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/linear_util.py", line 146, in call_wrapped
args, kwargs = next(gen)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/experimental/ode.py", line 51, in ravel_first_arg_
y = unravel(y_flat)
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/flatten_util.py", line 29, in <lambda>
unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat))
File "/home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/api.py", line 1291, in _vjp_pullback_wrapper
raise TypeError(msg.format(_dtype(a), dtype))
TypeError: Type of cotangent input to vjp pullback function (float64) does not match type of corresponding primal output (float32)
> /home/jboik/.virtualenvs/aai/lib/python3.7/site-packages/jax/api.py(1291)_vjp_pullback_wrapper()
-> raise TypeError(msg.format(_dtype(a), dtype))
@John-Boik Do you have a code sample you could share that reproduces the issue? It looks like the forward solve is not actually using float64 as you requested.
It was my fault. I was unintentionally passing odeint a matrix specified as float32.
@John-Boik and others, I would strongly advise against differentiating through an adaptive implicit ODE solver like LSODE for performance reasons, especially for a stiff problem. I would have time in June to help implement LSODE.
`jax.experimental.host_callback` looks like it could be a nice way to get auxiliary outputs out of ODE solvers. It's still very experimental, but it makes it possible to thread values out of `jit`-compiled code back into Python (e.g., see the example implementing a printer in https://github.com/google/jax/issues/3127).
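In case it helps, here's a rough, untested sketch of pulling a diagnostic out of jitted code with `host_callback.id_tap`; `record_steps`, `diagnostics`, and `fake_solver` are made-up names for illustration:

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback

diagnostics = []  # filled in on the host as jitted code runs

def record_steps(value, transforms):
    # Runs on the host with a concrete array copied back from the device.
    diagnostics.append(int(value))

@jax.jit
def fake_solver(y0):
    y1 = y0 * 2.0  # stand-in for an actual ODE solve
    # Tap out a made-up diagnostic (e.g. number of steps) alongside the result.
    y1 = host_callback.id_tap(record_steps, jnp.int32(42), result=y1)
    return y1

fake_solver(jnp.ones(3))
print(diagnostics)  # [42]
```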
@shoyer `jax.experimental.host_callback` looks like everything I've ever dreamed of. Haven't tried it yet, but I think my life may be complete now.
I'd also like to add a feature request for:

- [ ] A BVP solver, eg like `scipy.integrate.solve_bvp`
Another feature that I'm realizing would be important:

- [ ] Returning the "dense output" spline, eg the `dense_output=True` option in scipy's `solve_ivp`. All of the necessary math is done in the RK solve anyhow.

I have implemented a BDF solver for stiff ODEs in JAX using TF-probability's code. The implementation is pretty barebones right now and I still need to test JIT and VMAP. Also, I still need to implement an adjoint gradient method like TF-probability's, which I'm planning to do soon. I tested it against SciPy's VODE methods for stiff chemical kinetics problems and the results look pretty good to me. The current implementation is here and I'm planning to open a PR here once I get the adjoint gradient method, JIT, VMAP, etc. working. Feedback and suggestions are welcome! Thanks for such a great framework.
> - [ ] Returning the "dense output" spline, eg the `dense_output=True` option in scipy's `solve_ivp`. All of the necessary math is done in the RK solve anyhow.
To make this work in a way that is compatible with `jit`, we would need to support picking a (max) static number of interpolation points for the spline. But I agree, this would be nice to have, particularly for the adjoint calculation because the stored interpolation could be used instead of integrating the solution backwards in time to recompute primal values.
> To make this work in a way that is compatible with `jit`, we would need to support picking a (max) static number of interpolation points for the spline. But I agree, this would be nice to have, particularly for the adjoint calculation because the stored interpolation could be used instead of integrating the solution backwards in time to recompute primal values.
Yeah, XLA's allocation limitations present a bit of a challenge here.
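For what it's worth, a small sketch (my own illustration, not an existing API) of the dense-output evaluation being discussed, assuming the step times, states, and derivatives from the RK solve have been stored up to some static maximum:

```python
import jax.numpy as jnp

def hermite_eval(t_query, ts, ys, fs):
    # ts: (n,) step times; ys: (n, dim) states; fs: (n, dim) derivatives,
    # all recorded during the RK solve (up to a static max number of steps).
    k = jnp.clip(jnp.searchsorted(ts, t_query) - 1, 0, ts.shape[0] - 2)
    h = ts[k + 1] - ts[k]
    s = (t_query - ts[k]) / h
    # Standard cubic Hermite basis functions on [0, 1].
    h00 = 2 * s**3 - 3 * s**2 + 1
    h10 = s**3 - 2 * s**2 + s
    h01 = -2 * s**3 + 3 * s**2
    h11 = s**3 - s**2
    return (h00 * ys[k] + h10 * h * fs[k]
            + h01 * ys[k + 1] + h11 * h * fs[k + 1])
```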
> I'd also like to add a feature request for:
> - [ ] A BVP solver, eg like `scipy.integrate.solve_bvp`
I second this! I think this could be accomplished by using `odeint` in parallel with a multiple shooting scheme. Otherwise, one could perform direct collocation with a chosen quadrature (e.g. Hermite-Simpson) — I've done this with JAX and IPOPT, and it works pretty well. These methods require optimising a sparse nonlinear programme, which is more amenable to constrained optimisers. And I think extra consideration will be needed for how to handle the Jacobian and Hessian sparsity, as well as mesh refinement (which might be a problem for `jit`).
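Here's a rough, untested sketch of the multiple-shooting idea: integrate each segment with `odeint` under `vmap` and build a residual whose root enforces continuity and the boundary conditions (`bc_residual`, the knot layout, and the example dynamics are made up for illustration):

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def dynamics(y, t):
    # Example dynamics: simple harmonic oscillator.
    return jnp.array([y[1], -y[0]])

def shooting_residual(seg_starts, t_knots, bc_residual):
    # seg_starts: (n_segments, state_dim) guesses for the state at each knot.
    def integrate_segment(y0, t0, t1):
        return odeint(dynamics, y0, jnp.stack([t0, t1]))[-1]

    # Integrate every segment "in parallel" with vmap.
    seg_ends = jax.vmap(integrate_segment)(seg_starts, t_knots[:-1], t_knots[1:])
    # Continuity: the end of segment i must match the start of segment i + 1.
    continuity = (seg_ends[:-1] - seg_starts[1:]).ravel()
    # User-supplied boundary conditions on the first and last states.
    bc = bc_residual(seg_starts[0], seg_ends[-1])
    return jnp.concatenate([continuity, bc])
```

The residual (and its Jacobian via `jax.jacfwd`) could then be handed to a Newton or least-squares solver.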
btw, I have a rough implementation of some other RK solvers here, in case anyone has a use for them. :)
Just checking on this thread - we are very interested in getting fast gradients through a stiff ODE solution.
Some cool ideas in this thread - are there any updates since last year?
For those watching this thread -- check out Diffrax, which is a library doing pretty much everything discussed here. Other RK solvers, implicit solvers, discretise-then-optimise, etc.
The current implementation of Runge-Kutta with adjoint reverse-mode gradients is great, but there are a few things I still find myself missing, and I'd really love to help contribute, or just see in JAX one way or another.

- [ ] … `y0` and the "hopefully close" y(t_0) from backtracking through the dynamics in the adjoint ODE.
- [ ] … the `odeint` function with a vjp rule. Being able to select different solvers for the forward and adjoint passes would also be useful. Ideal solution would make arbitrary solver combos a cinch, eg. run RK on the forward pass, but Euler integration for the adjoint.