Issue closed by allen-adastra 7 months ago.
Given that `dynamics` will be evaluated multiple times for different `t` and `y`, how would you like `infos` to be defined? Something like, the `info` value for each `(t, y)` pair in `(sol.ts, sol.ys)`, all `tree_map(stack, ...)`'d together?
That's a reasonable feature request. Diffrax will be getting a fairly sweeping update (internally; the API won't change) soon. I'll try to add this then.
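For illustration, here is a minimal JAX sketch of the layout proposed above: one `info` pytree per saved `(t, y)` pair, stacked leaf-by-leaf so that every leaf gains a leading time axis. The function `vector_field_with_info` and the keys in its `info` dict are hypothetical, and `ts`/`ys` are stand-ins for `sol.ts`/`sol.ys` (the exact solution of dy/dt = -y with y(0) = 1), not output from Diffrax.

```python
import jax
import jax.numpy as jnp


def vector_field_with_info(t, y, args):
    # Hypothetical vector field: returns the derivative plus a pytree of
    # intermediate values ("info") computed along the way.
    decay = -y
    info = {"decay": decay, "t_times_y": t * y}
    return decay, info


# Stand-ins for sol.ts and sol.ys: exact solution of dy/dt = -y, y(0) = 1.
ts = jnp.array([0.0, 0.5, 1.0])
ys = jnp.exp(-ts)

# One info pytree per (t, y) pair...
infos_list = [vector_field_with_info(t, y, None)[1] for t, y in zip(ts, ys)]

# ...stacked leaf-by-leaf into a single pytree with a leading time axis.
infos = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *infos_list)
print(infos["decay"].shape)
```

Each leaf of `infos` then lines up index-for-index with `sol.ts`, which is the "one `info` for every `sol.ts`" structure.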
Yes, exactly, one `info` for every `sol.ts`. I imagine we'd want a flag to enable/disable this feature.
Checking in on this again. If you have ideas for a general approach to implement this, I may be able to take a crack!
Hey there!
So if the auxiliary information is something you're happy to recompute, then this can now be done with something like this:
```python
import diffrax as dfx
import jax.numpy as jnp

terms = dfx.ODETerm(lambda t, y, args: -y)
solver = dfx.Tsit5()
y0 = jnp.array(1.)
saveat_fn = lambda t, y, args: (y, y + 1)  # save both y and some auxiliary information
sol = dfx.diffeqsolve(terms, solver, t0=0, t1=1, dt0=0.1, y0=y0,
                      saveat=dfx.SaveAt(ts=jnp.array([0., 0.5, 1.]), fn=saveat_fn))
# sol.ys has the structure of saveat_fn's output, stacked along the saved times.
ys, aux = sol.ys
print(aux)
```
If the goal is to avoid recomputing the auxiliary information, then I'm realising this might get a lot trickier. For example, for many solvers, we will never actually evaluate the vector field at the output points at all! Diffrax works by making steps and then interpolating between them where necessary. So I'm not sure it would be possible to cleanly implement that in general.
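Because the vector field is generally not evaluated at the saved times, any auxiliary values have to be recomputed from the saved `(t, y)` pairs. Besides the `SaveAt(fn=...)` approach above, one option is to do this after the solve with `jax.vmap`. The sketch below assumes a hypothetical `vector_field_with_aux` that returns `(derivative, aux)`; `ts`/`ys` stand in for `sol.ts`/`sol.ys` (the exact solution of dy/dt = -2y, y(0) = 1) rather than coming from Diffrax.

```python
import jax
import jax.numpy as jnp


def vector_field_with_aux(t, y, args):
    # Hypothetical vector field: the derivative plus an intermediate value
    # that is useful for debugging.
    scale = args["rate"] * y
    return -scale, {"scale": scale}


args = {"rate": 2.0}
ts = jnp.array([0.0, 0.5, 1.0])  # stand-in for sol.ts
ys = jnp.exp(-2.0 * ts)          # stand-in for sol.ys: exact solution of dy/dt = -2y, y(0) = 1

# Re-evaluate the vector field at every saved (t, y) in one vectorised call.
# args is shared by all points, so it is not mapped over (in_axes=None).
_, aux = jax.vmap(vector_field_with_aux, in_axes=(0, 0, None))(ts, ys, args)
print(aux["scale"])
```

This costs one extra vector-field evaluation per saved point, in line with the "factor of 2 recomputation at most" estimate.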
Honestly this is great; it's a factor of 2 recomputation at most. Really I just wanted a clean API for it.
In some of my use cases, the function computing the derivatives also computes intermediate values that are valuable for debugging or use elsewhere. It would be great to be able to extract those values from `diffeqsolve`. It doesn't seem that this is a feature yet. Do let me know if it actually is! Example of what I would like to do: