KhayrullevJokhongir opened this issue 1 year ago.
I'm afraid not. All JAX arrays must have a size known at compile time, but the time of the event isn't known until runtime. Diffrax therefore initialises an array of the appropriate size (here, of length given by `saveat.ts`) filled entirely with `inf`, and then fills in this array as the integration progresses.
I hope that helps! :)
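The fixed-size constraint can be illustrated without Diffrax. The following is a minimal sketch in plain JAX, not Diffrax's actual implementation: a jitted function whose output shape is fixed at trace time, where a runtime condition only decides which entries hold real values and which are left as `inf`. The function `padded_outputs` and its stopping rule are hypothetical, chosen for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def padded_outputs(values, threshold):
    # Hypothetical stopping rule: stop at the first entry that reaches
    # `threshold`. Where that happens is only known at runtime.
    hit = (values >= threshold).astype(jnp.int32)
    # True for every entry before the first hit, False from the hit onwards.
    keep = jnp.cumsum(hit) == 0
    # The output has the same (static) shape as the input; entries that were
    # never "reached" are filled with inf rather than dropped.
    return jnp.where(keep, values, jnp.inf)


values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
out = padded_outputs(values, 3.5)  # shape (5,): the last two entries are inf
```

Returning `values[keep]` instead would fail under `jax.jit`, because the length of the result would depend on a traced value; this is the same constraint that leads Diffrax to pad with `inf`.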
It helps, thank you for the quick reply :)
Can I get some other value instead of `inf`, for example the state of the system at the last time step before the event triggered?
Not via Diffrax, but you could probably write some logic of your own to do that afterwards.
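One possible version of that post-processing, as a minimal sketch: assuming, as described above, that every unfilled row of `sol.ts` and `sol.ys` is `inf`, find the last finite save time and copy the state saved there into the padded rows. The helper `fill_after_event` is hypothetical, and the toy arrays stand in for `sol.ts` and `sol.ys`. Note that this repeats the last saved state (the last point of `saveat.ts` reached before termination), which is generally not the state at the exact moment the event triggered.

```python
import jax.numpy as jnp


def fill_after_event(ts, ys):
    """Replace the inf padding after a terminating event with the last saved state.

    Assumes the finite entries of `ts` form a prefix, i.e. rows are padded
    with inf from the first unsaved time onwards.
    """
    valid = jnp.isfinite(ts)                   # True for times actually saved
    last = jnp.maximum(jnp.sum(valid) - 1, 0)  # index of the last saved time
    # Broadcast the row mask across the state dimension and forward-fill.
    ys_filled = jnp.where(valid[:, None], ys, ys[last])
    return ys_filled, ys[last]


# Toy stand-ins for sol.ts and sol.ys (4 save times, 2 state components);
# here the event is assumed to fire between the 2nd and 3rd save times.
ts = jnp.array([0.0, 1.0, jnp.inf, jnp.inf])
ys = jnp.array([[10.0, 10.0], [9.0, 11.0], [jnp.inf, jnp.inf], [jnp.inf, jnp.inf]])

ys_filled, y_last = fill_after_event(ts, ys)
```

With the script below this would be `ys_filled, y_last = fill_after_event(sol.ts, sol.ys)`. Since every output shape is static, the helper can also be wrapped in `jax.jit`.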
I am solving the simple problem below using `DiscreteTerminatingEvent`. Once the event is triggered, the integration stops, but the solver returns `inf` values for the save times after the event's trigger time. Is there a way to avoid this, so that the solver returns solution values only for the times before the event triggers, similar to SciPy's `solve_ivp`?
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5, Dopri5, DiscreteTerminatingEvent


def vector_field(t, y, args):
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    return jnp.array([d_prey, d_predator])


# Define the terminating event function with two conditions
def terminating_event_fxn(state, args, **kwargs):
    prey_population = state.y[0]
    predator_population = state.y[1]
    # The two original conditions were not included in the issue; these
    # thresholds are hypothetical placeholders so that the script runs.
    return (prey_population < 9.0) | (predator_population > 15.0)


# Set up the ODE term, solver, and the initial conditions
term = ODETerm(vector_field)
solver = Dopri5()
t0 = 0
t1 = 140
dt0 = 0.1
y0 = jnp.array([10.0, 10.0])
args = (0.1, 0.02, 0.4, 0.02)
saveat = SaveAt(ts=jnp.linspace(t0, t1, 1000))

# Define the terminating event
terminating_event = DiscreteTerminatingEvent(terminating_event_fxn)

# Solve the ODE with the terminating event
sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args, saveat=saveat,
                  discrete_terminating_event=terminating_event)

# Plot the results
plt.plot(sol.ts, sol.ys[:, 0], label="Prey")
plt.plot(sol.ts, sol.ys[:, 1], label="Predator")
plt.legend()
plt.show()

print(sol.ys[:, 0].size)
print(sol.ts.shape)
```
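Following the suggestion to post-process the result, a minimal sketch of `solve_ivp`-like output is to drop the padded rows eagerly, outside `jax.jit`, just before plotting or printing. `trim_to_event` is a hypothetical helper, and the toy arrays stand in for `sol.ts` and `sol.ys`.

```python
import jax.numpy as jnp


def trim_to_event(ts, ys):
    """Drop the inf-padded rows left after a terminating event.

    Boolean-mask indexing gives a result whose length depends on the data,
    so call this eagerly (outside jax.jit), e.g. just before plotting.
    """
    mask = jnp.isfinite(ts)  # True only for times that were actually saved
    return ts[mask], ys[mask]


# Toy stand-ins for sol.ts and sol.ys: three saved times, then padding.
ts = jnp.array([0.0, 0.5, 1.0, jnp.inf])
ys = jnp.array([[10.0, 10.0], [9.5, 10.5], [9.0, 11.0], [jnp.inf, jnp.inf]])

ts_valid, ys_valid = trim_to_event(ts, ys)
```

In the script above, calling `ts, ys = trim_to_event(sol.ts, sol.ys)` and then plotting and printing `ts` and `ys` instead of `sol.ts` and `sol.ys` leaves only the points saved before the event.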