patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0
1.45k stars, 131 forks

Returning user-specified info from diffeqsolve #299

Closed: allen-adastra closed this issue 7 months ago

allen-adastra commented 1 year ago

In some of my use cases, the function computing the derivatives also computes intermediate values that are valuable for debugging or use elsewhere. It would be great to be able to extract those values from diffeqsolve. It doesn't seem that this is a feature yet. Do let me know if it actually is!

Example of what I would like to do:

```python
def dynamics(t, y, args):
    foo = physics_simulation(t, y, args)  # some expensive physics-based quantity
    derivs = f(t, y, args, foo)
    info = {"foo": foo}
    return derivs, info

solution = diffrax.diffeqsolve(diffrax.ODETerm(dynamics), ...)
ys = solution.ys
infos = solution.infos  # proposed attribute; not part of the current Solution API
```

patrick-kidger commented 1 year ago

Given that dynamics will be evaluated multiple times for different t and y, how would you like infos to be defined? Something like: the info value for each (t, y) pair in (sol.ts, sol.ys), all tree_map(stack, ...)'d together?
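
To make the "tree_map(stack, ...)" layout concrete, here is a minimal plain-JAX sketch. It assumes one info pytree per saved time (the dicts and their values are made up for illustration) and stacks matching leaves along a new leading axis, the same time-major layout that sol.ys has when using SaveAt(ts=...).

```python
import jax
import jax.numpy as jnp

# Hypothetical info pytrees, one per saved time point (e.g. one per entry of sol.ts).
infos_per_time = [
    {"foo": jnp.array(0.0), "bar": jnp.array([1.0, 2.0])},
    {"foo": jnp.array(0.5), "bar": jnp.array([3.0, 4.0])},
    {"foo": jnp.array(1.0), "bar": jnp.array([5.0, 6.0])},
]

# tree_map walks all the pytrees in parallel; for each leaf position it receives
# the matching leaves from every time point and stacks them along a new axis 0.
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *infos_per_time)

print(stacked["foo"].shape)  # (3,)
print(stacked["bar"].shape)  # (3, 2)
```

Each leaf gains a leading axis whose length is the number of saved times, while the dict structure itself is unchanged.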

That's a reasonable feature request. Diffrax will be getting a fairly sweeping update (internally; the API won't change) soon. I'll try to add this then.

allen-adastra commented 1 year ago

Yes, exactly: one info for every time in sol.ts. I imagine we'd want a flag to enable or disable this feature.

allen-adastra commented 8 months ago

Checking in on this again. If you have ideas for a general approach to implementing this, I may be able to take a crack at it!

patrick-kidger commented 8 months ago

Hey there!

So if the auxiliary information is something you're happy to recompute, then this can now be done with something like this:

```python
import diffrax as dfx
import jax.numpy as jnp

terms = dfx.ODETerm(lambda t, y, args: -y)
solver = dfx.Tsit5()
y0 = jnp.array(1.)
saveat_fn = lambda t, y, args: (y, y + 1)  # save both y and some auxiliary information
sol = dfx.diffeqsolve(
    terms, solver, t0=0, t1=1, dt0=0.1, y0=y0,
    saveat=dfx.SaveAt(ts=[0, 0.5, 1], fn=saveat_fn),
)
ys, aux = sol.ys
print(aux)
```

If the goal is to avoid recomputing the auxiliary information, then I'm realising this might get a lot trickier. For example, for many solvers, we will never actually evaluate the vector field at the output points at all! Diffrax works by making steps and then interpolating between them where necessary. So I'm not sure it would be possible to cleanly implement that in general.

allen-adastra commented 7 months ago

Honestly this is great; it's a factor of 2 recomputation at most. Really I just wanted a clean API for it.