patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

inf values after triggering event function. #335

Open KhayrullevJokhongir opened 1 year ago

KhayrullevJokhongir commented 1 year ago

I am solving the simple problem below using DiscreteTerminatingEvent. Once the event is triggered, the integration stops, but the solver returns inf values for all save times after the event's trigger time. Is there a way to avoid this, so that the solver returns the solution only at the time steps before the event was triggered, as SciPy's solve_ivp does?

```python
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5, Dopri5, DiscreteTerminatingEvent

def vector_field(t, y, args):
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    return jnp.array([d_prey, d_predator])

# Define the terminating event function with two conditions
def terminating_event_fxn(state, args, **kwargs):
    prey_population = state.y[0]
    predator_population = state.y[1]
    A = (prey_population < 5) | (predator_population > 15)
    return A

# Set up the ODE term, solver, and the initial conditions
term = ODETerm(vector_field)
solver = Dopri5()
t0 = 0
t1 = 140
dt0 = 0.1
y0 = jnp.array([10.0, 10.0])
args = (0.1, 0.02, 0.4, 0.02)
saveat = SaveAt(ts=jnp.linspace(t0, t1, 1000))

# Define the terminating event
terminating_event = DiscreteTerminatingEvent(terminating_event_fxn)

# Solve the ODE with the terminating event
sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args, saveat=saveat,
                  discrete_terminating_event=terminating_event)

# Plot the results
plt.plot(sol.ts, sol.ys[:, 0], label="Prey")
plt.plot(sol.ts, sol.ys[:, 1], label="Predator")
plt.legend()
plt.show()

print(sol.ys[:, 0].size)
print(sol.ts.shape)
```

patrick-kidger commented 1 year ago

I'm afraid not. All JAX arrays have to have a size known at compile time. However, the time of the event isn't known until runtime. As such, Diffrax works by initialising an array of the appropriate size (here, the length of saveat.ts), filled entirely with inf. It then fills in this array as the integration progresses.
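To make that mechanism concrete, here is a toy sketch in plain JAX. It is not Diffrax's implementation: the explicit Euler step, the decay equation dy/dt = -y, the event threshold 0.5 and the name `toy_solve` are all illustrative assumptions. What it does show is the pattern described above: under `jax.jit` the output buffer is pre-allocated with one row per save time, filled with inf, and only the rows reached before the event fires get overwritten.

```python
import jax
import jax.numpy as jnp


@jax.jit
def toy_solve(y0, ts):
    # Toy integrator (hypothetical, not Diffrax): assumes uniformly spaced
    # save times and takes one explicit Euler step per save time.
    dt = ts[1] - ts[0]
    # Output shape is fixed at trace time: one row per save time, all inf.
    ys = jnp.full((ts.shape[0],) + y0.shape, jnp.inf, dtype=y0.dtype)

    def cond(carry):
        i, y, ys = carry
        # Continue while save slots remain and the toy event (y[0] < 0.5)
        # has not fired.
        return (i < ts.shape[0]) & (y[0] >= 0.5)

    def body(carry):
        i, y, ys = carry
        ys = ys.at[i].set(y)  # overwrite this slot's inf with the state
        y = y + dt * (-y)     # explicit Euler step of dy/dt = -y
        return i + 1, y, ys

    _, _, ys = jax.lax.while_loop(cond, body, (jnp.int32(0), y0, ys))
    return ys


ys = toy_solve(jnp.array([1.0]), jnp.linspace(0.0, 1.0, 11))
print(jnp.isfinite(ys[:, 0]))  # 7 True then 4 False: rows after the event stay inf
```

The trailing rows cannot simply be dropped inside the jitted function, because the number of rows kept would depend on runtime data.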

I hope that helps! :)

KhayrullevJokhongir commented 1 year ago

It helps, thank you for the quick reply :)

KhayrullevJokhongir commented 1 year ago

Can I get some other value instead of inf, for example the state of the system at the last time step before the event triggered?

patrick-kidger commented 1 year ago

Not via Diffrax, but you could probably write some logic of your own to do that afterwards.
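A minimal sketch of such post-processing follows. It assumes that both `sol.ts` and `sol.ys` are inf-padded past the event, that the valid entries form a prefix (saves are filled in order), and that at least one save time was reached before the event. `trim_to_event` and `fill_with_last_state` are hypothetical helper names, not part of the Diffrax API.

```python
import jax.numpy as jnp


def trim_to_event(ts, ys):
    """Drop inf-padded rows, giving solve_ivp-style output.

    Must run outside jax.jit: the result's shape depends on when the event fired.
    """
    valid = jnp.isfinite(ts)
    return ts[valid], ys[valid]


def fill_with_last_state(ts, ys):
    """Replace inf-padded rows of ys with the last state saved before the event.

    Keeps the original (static) shape, so it also works inside jax.jit.
    """
    valid = jnp.isfinite(ts)
    # Valid entries form a prefix, so the last one sits at index count - 1.
    last = jnp.sum(valid) - 1
    # Reshape the mask so it broadcasts against ys of any rank.
    mask = valid.reshape((-1,) + (1,) * (ys.ndim - 1))
    return jnp.where(mask, ys, ys[last])
```

With the solve above, `ts, ys = trim_to_event(sol.ts, sol.ys)` would give output comparable to solve_ivp, while `fill_with_last_state(sol.ts, sol.ys)` keeps all 1000 rows but holds the final pre-event state constant. Masking on `ts` rather than `ys` avoids discarding a state that legitimately blew up to inf.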