patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Getting the derivative of the evolving state through events #458

Closed etienney closed 2 months ago

etienney commented 2 months ago

I want to use a terminating event, similar to the steady-state event, where the condition is checked on the derivative of the state being integrated. Looking at the code in diffrax._event, this can only be done by re-evaluating the terms, which is computationally inefficient, since the vector field has already been evaluated during the integration in solver.step. I need to tackle this issue. I need it for an RK solver, but let's take Euler as an example. I think the way to do it is to return terms.vf_prod(t0, y0, args, control) from the step, then pull that value out of solver.step in diffrax._integrate and add the derivative to the State class? I'm not sure it's very elegant, though.
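To make the idea concrete, here is a minimal sketch in plain Python rather than diffrax. The names State, euler_step and integrate_until_steady are hypothetical and exist only for this illustration; they are not the diffrax API. It assumes a scalar ODE and shows the step handing back the vector field value it already computed, so the steady-state check costs no extra evaluation.

```python
from dataclasses import dataclass


@dataclass
class State:
    # Hypothetical solver state, not diffrax's State class.
    t: float
    y: float
    dy_dt: float  # vector field value computed by the most recent step


def euler_step(f, t0, t1, y0, args):
    """One explicit Euler step that also returns f(t0, y0, args)."""
    dy_dt = f(t0, y0, args)  # the only vector field evaluation in the step
    y1 = y0 + dy_dt * (t1 - t0)
    return y1, dy_dt


def integrate_until_steady(f, t0, y0, args, dt, tol, max_steps=10_000):
    """Integrate until the cached derivative is smaller than tol."""
    state = State(t=t0, y=y0, dy_dt=float("inf"))
    n_evals = 0
    for _ in range(max_steps):
        y1, dy_dt = euler_step(f, state.t, state.t + dt, state.y, args)
        n_evals += 1
        # dy_dt was evaluated at the start of the step, so the event sees
        # the derivative one step behind the newly computed state.
        state = State(t=state.t + dt, y=y1, dy_dt=dy_dt)
        if abs(state.dy_dt) < tol:  # event check reuses the cached value
            break
    return state, n_evals
```

The price of reusing the value is that the condition is checked on the derivative at the beginning of the step, not at the new state. Whether that lag is acceptable depends on the event.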

patrick-kidger commented 2 months ago

Right, this would definitely be a nice-to-have.

This ends up being pretty tricky, however, because many solvers don't actually compute this value. To explain: when we evaluate a term, we usually evaluate either term.vf(t0, y0, args) or term.vf_prod(t0, y0, args, control). The first one is the value of the vector field f(t, y, args). The second one then multiplies that by the increment of the control, e.g. f(t, y, args) Δt for an ODE solve.

For events we typically want the first one, whereas for the integration itself we typically want the second one! And in general, and in particular for SDEs, computing vf_prod does not imply also computing vf. For example, when working with an SDE with diagonal noise, we may only compute the diagonal of the vector field (the only part that is nonzero), not the whole thing.
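As a rough illustration of the distinction, here is a plain-Python sketch. The functions ode_vf, ode_vf_prod, diag_vf_prod and dense_vf are hypothetical stand-ins for this example, not diffrax's own methods. The ODE case computes the product from the vector field, while the diagonal-noise case computes the product directly and never forms the full matrix.

```python
def ode_vf(t, y, args):
    """Vector field f(t, y, args) of the linear ODE dy/dt = -args * y."""
    return [-args * yi for yi in y]


def ode_vf_prod(t, y, args, control):
    """f(t, y, args) times the control increment (t1 - t0 for an ODE)."""
    return [fi * control for fi in ode_vf(t, y, args)]


def diag_vf_prod(t, y, args, control):
    """Diagonal-noise diffusion applied to a Brownian increment.

    Only the diagonal entries args * y[i] are computed and multiplied
    elementwise by the increment; the full matrix is never built.
    """
    return [args * yi * dwi for yi, dwi in zip(y, control)]


def dense_vf(t, y, args):
    """The full n-by-n diffusion matrix, shown only for comparison."""
    n = len(y)
    return [[args * y[i] if i == j else 0.0 for j in range(n)] for i in range(n)]


def matvec(matrix, vector):
    """Plain matrix-vector product."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]
```

Here diag_vf_prod gives the same result as matvec(dense_vf(...), control), but nothing in it produces the matrix that an event asking for the whole vector field would need.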

For what it's worth, I've still done a first pass at this over on this branch: https://github.com/patrick-kidger/diffrax/tree/event-efficiency . This does the work to thread the information from AbstractSolver.step through to events, the integration itself, etc. The difficult bit is still finding a way to have every solver provide that information at all. If you'd like to think about that, then hopefully the above is a useful starting point!

etienney commented 2 months ago

Okay, I understand the problem! Thanks for the link to the branch and for your answer.

JesseFarebro commented 1 month ago

Thanks for providing the plumbing for this, @patrick-kidger. I wanted to add another use case: if I have a target vector field parameterized by a neural network and want to perform flow matching, I need to solve the ODE and also get the vector field at each of the discretization points. It would be nice to have support for this, even if only for particular solvers.
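For a solver where this does work, a small plain-Python sketch of what I mean follows. The function heun_solve_with_vf is hypothetical and is not part of diffrax, and f stands in for the neural network. With Heun's method the first stage is exactly the vector field at the current discretization point, so it can be recorded for free.

```python
def heun_solve_with_vf(f, t0, t1, y0, args, n_steps):
    """Fixed-step Heun solve that records f at every step's starting point."""
    dt = (t1 - t0) / n_steps
    ts, ys, vs = [], [], []
    t, y = t0, y0
    for _ in range(n_steps):
        k1 = f(t, y, args)  # vector field at the discretization point (t, y)
        k2 = f(t + dt, y + dt * k1, args)  # second stage at the Euler predictor
        ts.append(t)
        ys.append(y)
        vs.append(k1)  # reuse the first stage, no extra evaluation
        y = y + 0.5 * dt * (k1 + k2)
        t = t + dt
    # The vector field at the final point is not recorded; getting it
    # would cost one additional evaluation of f.
    return ts, ys, vs, y
```

This only works because the first stage of this particular method coincides with the vector field at the start of the step, which is why support limited to particular solvers would already be useful.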