patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Making subsaveat consider previous saves and not just the current one #472

Open · etienney opened this issue 4 months ago

etienney commented 4 months ago

When using SaveAt we have the option to pass a fn, defined in the docs as:

fn: A function fn(t, y, args) which specifies what to save into sol.ys when using t0, t1, ts or steps. Defaults to fn(t, y, args) -> y, so that the evolving solution is saved. For example this can be useful to save only statistics of your solution, so as to reduce memory usage.
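
For concreteness, here is a minimal sketch of that existing option (the statistic saved is just illustrative):

import diffrax
import jax.numpy as jnp

# Save only a summary statistic (here the mean of y) at each save time,
# rather than the full state, to reduce memory usage.
saveat = diffrax.SaveAt(
    ts=jnp.linspace(0.0, 1.0, 11),
    fn=lambda t, y, args: jnp.mean(y),
)
# sol = diffrax.diffeqsolve(..., saveat=saveat)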

But why not change it to be something like fn(save_state, t, y, args)? In _integrate.py it is called at line 218, under

def _save(
    t: FloatScalarLike,
    y: PyTree[Array],
    args: PyTree,
    fn: Callable,
    save_state: SaveState,
) -> SaveState:
    ts = save_state.ts
    ys = save_state.ys
    save_index = save_state.save_index

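    # Write the current time and the output of `fn` into the next free slot
    # of the save buffers.  (jtu is jax.tree_util; eqx is equinox.)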
    ts = ts.at[save_index].set(t)
    ys = jtu.tree_map(lambda ys_, y_: ys_.at[save_index].set(y_), ys, fn(t, y, args))
    save_index = save_index + 1

    return eqx.tree_at(
        lambda s: [s.ts, s.ys, s.save_index], save_state, [ts, ys, save_index]
    )

We could then save a function that depends not only on the current state but also on the previous ones (which would be useful in my case, and I believe is strictly more general than the current version, "for free"). We would also need to modify the default fn in _saveat.py from

def save_y(t, y, args):
    return y

to

def save_y(save_state, t, y, args):
    return y

and I think that should be okay? Of course I could do this myself and work with such a fork of diffrax, but I'm working on a library which depends on a library which itself depends on diffrax, so I'm interested in having this in the "real" diffrax, so that my library keeps up with the versions above it, haha.
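
To illustrate, with the proposed signature one could then write something like the following (a hypothetical sketch; save_increment is just an illustrative name, and the first save has no previous value so it would need special-casing):

import jax.tree_util as jtu

# Hypothetical user-supplied fn under the proposed signature: save the
# increment y_n - y_{n-1} instead of y_n itself.
def save_increment(save_state, t, y, args):
    i = save_state.save_index
    # At i == 0 there is no previous save; a real implementation would need
    # to handle that first step separately (e.g. with jnp.where).
    prev_y = jtu.tree_map(lambda ys_: ys_[i - 1], save_state.ys)
    return jtu.tree_map(lambda y_, p_: y_ - p_, y, prev_y)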

patrick-kidger commented 4 months ago

So save_state is an internal implementation detail that I don't think should be exposed to users -- it may change from release to release.

More generally, this kind of change might introduce dependencies that I'm not confident will play well with autodifferentiation. Diffrax does some fairly complicated things here to fill in buffers efficiently during the iteration, whilst also remaining autodiff-friendly.

What's your use case?

etienney commented 4 months ago

Okay, save_state may be too much to give to users, but maybe save_index, save_state.ts and save_state.ys would be nice?

The use case is to compute functions with entries like f(t_n, y_n, args_n, t_{n-1}, y_{n-1}, args_{n-1}, n) (or with any n-i), where n is the nth iteration (t_n would be ts[save_index] in the current formalism), along the simulation.

patrick-kidger commented 4 months ago

I think for this case it'd be best to just save output as normal, and then do an additional scan over the saved values (after the diffeqsolve) to compute your desired output.
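
For example, something along these lines (a rough sketch; f here is just an illustrative pairwise function):

import jax
import jax.numpy as jnp

# Illustrative pairwise function: a finite-difference rate of change.
def f(t_n, y_n, t_prev, y_prev):
    return (y_n - y_prev) / (t_n - t_prev)

def pairwise_outputs(ts, ys):
    # Scan over consecutive saved values, carrying the previous (t, y) pair.
    def step(carry, current):
        t_prev, y_prev = carry
        t_n, y_n = current
        return (t_n, y_n), f(t_n, y_n, t_prev, y_prev)

    _, outs = jax.lax.scan(step, (ts[0], ys[0]), (ts[1:], ys[1:]))
    return outs

# e.g. outs = pairwise_outputs(sol.ts, sol.ys) after the diffeqsolve.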

etienney commented 4 months ago

Then you would not be able to use the saved output inside an event, though. Of course this can be done at the end, but the idea is to compute it along the way precisely for that reason.