patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Solving NODE with implicit adjoint & steady state fails because an event occurred #520

Open daveekr opened 2 weeks ago

daveekr commented 2 weeks ago

Hello,

I am trying to solve a neural ODE using diffrax.ImplicitAdjoint(), adaptive step sizing, and the steady-state event, as suggested in the documentation. However, I get the following error message, which I do not fully understand:

equinox.EquinoxRuntimeError: Terminating differential equation solve because an event occurred.

I use t0 = 0 and t1 = jnp.inf, together with this event:

```python
import diffrax

# Fire once the state has (approximately) stopped changing.
cond_fn = diffrax.steady_state_event(rtol=1e-4, atol=1e-6)
event = diffrax.Event(cond_fn)
```
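To make the stopping rule concrete, here is a toy sketch in plain Python (no JAX). It assumes the steady-state test has the form rms(f(t, y)) < atol + rtol * rms(y), which is an assumption based on the diffrax documentation for `steady_state_event`. The fixed-step Euler loop and the helpers `rms` and `integrate_to_steady_state` are hypothetical illustrations, not diffrax internals.

```python
import math


def rms(v):
    # Root-mean-square norm of a list of floats.
    return math.sqrt(sum(x * x for x in v) / len(v))


def integrate_to_steady_state(f, y0, dt, rtol, atol, max_steps):
    """Hypothetical toy: fixed-step explicit Euler until a steady state is detected.

    Returns (t, y, event_occurred).
    """
    t, y = 0.0, list(y0)
    for _ in range(max_steps):
        dy = f(t, y)
        # Assumed steady-state test: rms(f(t, y)) < atol + rtol * rms(y).
        if rms(dy) < atol + rtol * rms(y):
            return t, y, True
        y = [yi + dt * dyi for yi, dyi in zip(y, dy)]
        t += dt
    # No event within max_steps: with t1 = inf, this cap is the only other way out.
    return t, y, False


def decay(t, y):
    # dy/dt = 0.5 - y: every component decays towards the steady state 0.5.
    return [0.5 - yi for yi in y]


t, y, hit = integrate_to_steady_state(decay, [2.0, -1.0], 0.1, 1e-4, 1e-6, 10_000)
print(hit, t, y)
```

Calling the same function with `max_steps=5` exits without the event firing, which mirrors the case where, with `t1 = jnp.inf`, only `max_steps` bounds the solve.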

As I understand it, I want the solution at the steady state (diffrax.SaveAt(t1=True)), so I would expect the event to fire once the steady state is reached, which then stops the differential equation solve. Otherwise, since t1 = jnp.inf, the solver would run forever unless capped by max_steps. Is that correct?

Any help is much appreciated, thanks!

patrick-kidger commented 2 weeks ago

Can you provide an MWE (minimal working example) demonstrating the issue?