patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Restart after event #492

Open dkweiss31 opened 1 month ago

dkweiss31 commented 1 month ago

I'm wondering whether you'd be interested in a PR that would allow solves terminated by an event to be restarted (after calling some user-supplied update function), repeatedly, until t1 is reached?

This is very relevant for quantum applications: a typical trajectory can experience multiple events, or "jumps", before the final time. If, however, you don't think this would be of general interest, I'm happy to implement it outside of diffrax.

patrick-kidger commented 1 month ago

I think this is probably best done outside of Diffrax -- users may wish to do arbitrary things across the jump, and I think it'd be hard to provide an API that supports them all.

I'll note that doing this restart is just a matter of wrapping diffrax.diffeqsolve in a lax.while_loop, so it should be straightforward.

dkweiss31 commented 1 month ago

Understood, thank you!