Fixes the `merge_t_args_jax` and `trim_t_results_jax` functions. I think this is a final, comprehensive solution to the problem that led to #147.
Details and comments
The bug stems from the following annoying technical issue that we've dealt with for a while:
We've adopted SciPy's `solve_ivp` time-argument structure for specifying ODE problems in `solve_ode` and `solve_lmde`. `t_span` specifies the integration interval, and `t_eval` is an optional argument that tells the solver to return the state only at specific points in time (otherwise, it returns the state at whatever times the solver stepped to).
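For reference, here is a minimal sketch of the `solve_ivp` argument structure that `solve_ode`/`solve_lmde` mirror, calling SciPy directly on a toy ODE (the right-hand side and tolerances are arbitrary choices for illustration, not anything from our code):

```python
import numpy as np
from scipy.integrate import solve_ivp


def rhs(t, y):
    # Toy ODE y' = -y (exact solution y(t) = exp(-t) * y(0)); solve_ivp passes (t, y).
    return -y


y0 = np.array([1.0])
t_span = (0.0, 2.0)                      # integration interval [t0, tf]
t_eval = np.array([0.0, 0.5, 1.0, 2.0])  # times at which to report the state

# With t_eval, the state is reported only at the requested times.
sol = solve_ivp(rhs, t_span, y0, t_eval=t_eval, rtol=1e-8, atol=1e-10)

# Without t_eval, the state is reported at whatever times the solver stepped to.
sol_steps = solve_ivp(rhs, t_span, y0)

print(sol.t)        # the entries of t_eval
print(sol.y.shape)  # (1, 4): one column per entry of t_eval
print(sol_steps.t)  # solver-chosen times, starting at t0 and ending at tf
```

Note that the endpoints of `t_span` also appear in `t_eval` here, which is the common case discussed below.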
JAX's `odeint` function takes a single time argument, `ts`, which is essentially like `t_eval`, except that its first and last entries are also interpreted as the endpoints of the integration interval. Hence, interfacing `odeint` with our `solve_ode`/`solve_lmde` functions requires translating the `t_span`/`t_eval` arguments into a single array of times to pass to `odeint`.
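A minimal sketch of this translation in the easy case where `t_eval` lies strictly inside `t_span`, using `jax.experimental.ode.odeint` on a toy ODE (the ODE, tolerances, and variable names are illustrative assumptions, not what `solve_ode` actually does internally):

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint


def rhs(y, t):
    # odeint calls func(y, t, *args); toy ODE y' = -y.
    return -y


y0 = jnp.array([1.0])
t_span = jnp.array([0.0, 2.0])
t_eval = jnp.array([0.5, 1.0, 1.5])  # strictly inside t_span: no duplicates here

# Fold (t_span, t_eval) into the single `ts` argument: ts[0] and ts[-1] set the
# integration interval, and every entry of ts is also an output time.
ts = jnp.concatenate([t_span[:1], t_eval, t_span[1:]])
ys = odeint(rhs, y0, ts, rtol=1e-6, atol=1e-6)

# ys has one row per entry of ts (ys[0] is y0), so dropping the first and last
# rows leaves the states at t_eval.
y_at_t_eval = ys[1:-1]
```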
One annoying peculiarity of `odeint`, and the cause of this bug, is that if the first or last entry of `ts` is repeated, the solution will contain `nan` in some entries (I don't remember exactly which). This is a problem for us because the endpoints of `t_span` very commonly coincide with the endpoints of `t_eval`. Hence, merging `t_span` and `t_eval` into a single array by prepending/appending the entries of `t_span` to `t_eval` will commonly produce repeated entries, triggering this `nan` error. (Note also that we can't simply use `if` statements to check whether these endpoints are equal and then decide whether to append them: this would violate the JAX requirement that function inputs/outputs have fixed shapes :D.)
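The sketch below illustrates both halves of the problem on made-up inputs. Unconditionally prepending/appending `t_span` keeps the output shape fixed but repeats the endpoints when they coincide with those of `t_eval`; dropping the duplicates with a Python `if` instead fails under `jax.jit`, because at trace time the array values are abstract tracers and cannot be converted to a Python `bool`. `naive_merge` and `merge_with_if` are hypothetical helpers, not functions from the codebase.

```python
import jax
import jax.numpy as jnp


def naive_merge(t_span, t_eval):
    # Output shape is always len(t_eval) + 2, so this traces fine under jit.
    return jnp.concatenate([t_span[:1], t_eval, t_span[1:]])


t_span = jnp.array([0.0, 1.0])
t_eval = jnp.array([0.0, 0.25, 0.5, 0.75, 1.0])  # endpoints coincide with t_span

ts = jax.jit(naive_merge)(t_span, t_eval)
# ts[0] == ts[1] and ts[-2] == ts[-1]: the first and last entries are repeated.
has_duplicates = bool(jnp.any(jnp.diff(ts) == 0))


def merge_with_if(t_span, t_eval):
    # Branching on array *values* would make the output shape value-dependent;
    # under jit the comparison is a tracer, so `if` raises at trace time.
    if t_span[0] == t_eval[0]:
        t_eval = t_eval[1:]
    return jnp.concatenate([t_span[:1], t_eval, t_span[1:]])


try:
    jax.jit(merge_with_if)(t_span, t_eval)
    traced_ok = True
except jax.errors.ConcretizationTypeError:
    traced_ok = False
```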
I thought I had fixed this problem in #125: where the `nan`s appear in the solution when duplicates are present is predictable, so `trim_t_results_jax` was written to avoid those entries. However, #147 was opened after discovering that this change made gradients of solutions evaluate to `nan`. I think this is because `odeint`'s gradient computation always uses all intermediate computed solution values, so even though we were removing the `nan`s from the output, the gradient computation still hit them, which we had no way of avoiding.

This PR fixes this by changing the "merging strategy" of `merge_t_args_jax` and `trim_t_results_jax`. Rather than just appending the entries of `t_span` to the ends of `t_eval`, we append them and then, conditional on whether the endpoints are equal, modify the combined array so that it contains no duplicates; `trim_t_results_jax` then selects the correct entries when trimming the results. Importantly, this does not fundamentally change the solver's behaviour.
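Below is a minimal sketch of this kind of strategy, under stated assumptions: `merge_t_args_sketch` and `trim_t_results_sketch` are hypothetical stand-ins rather than the PR's actual code, and replacing a duplicated interior entry with the midpoint to its neighbour is just one possible way to remove duplicates. It assumes `t_eval` is non-empty, monotone in the direction of `t_span`, and lies within `t_span`. The equality checks go through `jnp.where` rather than a Python `if`, so output shapes depend only on input shapes and the whole thing can be `jit`-ted and differentiated.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def merge_t_args_sketch(t_span, t_eval):
    """Hypothetical sketch: merge t_span and t_eval into a duplicate-free ts.

    Always produces len(t_eval) + 2 entries, regardless of the values.
    """
    ts = jnp.concatenate([t_span[:1], t_eval, t_span[1:]])
    front_eq = t_eval[0] == t_span[0]
    back_eq = t_eval[-1] == t_span[1]
    # If the front is duplicated, move ts[1] to the midpoint between it and ts[2].
    ts = ts.at[1].set(jnp.where(front_eq, 0.5 * (ts[1] + ts[2]), ts[1]))
    # Same at the back; this reads the already-updated ts[-3], so ts stays
    # strictly monotone even when t_eval == t_span.
    ts = ts.at[-2].set(jnp.where(back_eq, 0.5 * (ts[-3] + ts[-1]), ts[-2]))
    return ts, front_eq, back_eq


def trim_t_results_sketch(ys, front_eq, back_eq):
    """Hypothetical sketch: pick out the states at t_eval from odeint's output."""
    out = ys[1:-1]
    # A duplicated front means t_eval[0] == t_span[0], whose state is ys[0].
    out = out.at[0].set(jnp.where(front_eq, ys[0], out[0]))
    # A duplicated back means t_eval[-1] == t_span[1], whose state is ys[-1].
    out = out.at[-1].set(jnp.where(back_eq, ys[-1], out[-1]))
    return out


def rhs(y, t, a):
    # Toy ODE y' = -a * y, so y(t) = exp(-a * t) * y(0).
    return -a * y


@jax.jit
def solve_at(a, y0, t_span, t_eval):
    ts, front_eq, back_eq = merge_t_args_sketch(t_span, t_eval)
    ys = odeint(rhs, y0, ts, a, rtol=1e-6, atol=1e-6)
    return trim_t_results_sketch(ys, front_eq, back_eq)


y0 = jnp.array([1.0])
t_span = jnp.array([0.0, 1.0])
t_eval = jnp.array([0.0, 0.5, 1.0])  # endpoints coincide with t_span
a = jnp.array(1.0)

ys = solve_at(a, y0, t_span, t_eval)  # approx exp(-t_eval), shape (3, 1)
# Gradient of the final state y(1) = exp(-a) with respect to a.
grad_final = jax.grad(lambda a: solve_at(a, y0, t_span, t_eval)[-1, 0])(a)
```

The interval handed to `odeint` is always `[t_span[0], t_span[1]]`; only the interior output times move. When an endpoint was duplicated, the state at that endpoint is already available as `ys[0]` or `ys[-1]`, which is why the trimming step can recover exactly the requested times.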
Summary
Closes #147