See the discussion in #122 for a summary of the bug including technical details. This PR fixes the bug by taking the following actions (described in this comment):
The bug for method='jax_odeint' has been fixed by updating the internal time-value handling of jax_odeint.
The bug for method being a diffrax solver has been fixed with minimal changes. This code path was calling the problematic time-handling functions, but that was unnecessary: diffrax handles these time arguments itself, so dynamics doesn't need to.
The bug for method being a dynamics-based fixed-step solver is now avoided by triggering the automatic jitting behaviour only in the two cases above. (There isn't a simple fix that makes jitting work for these solvers.)
Details and comments
The following changes have been made:
Added tests for Solver that trigger the errors described in #122; these tests now pass.
For the diffrax solver fix:
Removed the unnecessary call to merge_t_args.
For jax_odeint:
Modified merge_t_args and trim_t_args to always append/remove the endpoints. The original logic checked whether any endpoints of t_eval and t_span were the same and, if so, did not append/remove the corresponding time points. I realized that this special handling is not actually necessary; it is fine to always add/remove the endpoints.
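The plain-Python sketch below makes the always-append/always-remove behaviour concrete. The names merge_t_args_sketch and trim_t_args_sketch are hypothetical simplifications, not the actual dynamics functions. The real functions also validate their inputs and operate on arrays.

```python
def merge_t_args_sketch(t_span, t_eval=None):
    """Hypothetical simplification: return [t_span[0], *t_eval, t_span[1]].

    The endpoints are always added, even when t_eval already starts or
    ends at them, so the output length is always len(t_eval) + 2.
    """
    t0, tf = t_span
    if t_eval is None:
        return [t0, tf]
    return [t0, *t_eval, tf]


def trim_t_args_sketch(times, ys, t_eval=None):
    """Undo merge_t_args_sketch: keep only the entries at the t_eval points.

    Because merging always adds exactly two endpoints, trimming always
    drops the first and last entries (when t_eval was given).
    """
    if t_eval is None:
        return times, ys
    return times[1:-1], ys[1:-1]


# usage: t_eval shares both endpoints with t_span
merged = merge_t_args_sketch([0.0, 1.0], [0.0, 0.5, 1.0])
times, ys = trim_t_args_sketch(merged, [2 * t for t in merged], [0.0, 0.5, 1.0])
```

When t_eval shares an endpoint with t_span, the merged list contains a repeated time point. Per the description above, that is harmless, and it means the output length never depends on the input values.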
Added merge_t_args_jax and trim_t_args_jax as JAX-compilable versions of these functions. To keep validation compilable, most validation checks in merge_t_args_jax are now written with JAX conditionals. If a validation check fails, the output is filled with nan values; unfortunately, this is the only way to "raise an error" in a JAX-compilable way.
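The sketch below shows nan-based validation in a jit-compiled merge. It assumes forward integration (t_span[0] <= t_span[1]) for brevity. merge_t_args_jax_sketch is a hypothetical name, and its check is illustrative, not the set of checks in merge_t_args_jax. A Python if/raise on the validity flag would fail under jax.jit because the flag is a traced value. The check is therefore folded into the output with jnp.where.

```python
import jax
import jax.numpy as jnp


def merge_t_args_jax_sketch(t_span, t_eval):
    """Hypothetical JAX-compilable merge with nan-based validation.

    The output shape depends only on the input shapes. If the merged times
    are not non-decreasing (t_eval unsorted or outside t_span), every entry
    of the output is replaced with nan instead of raising an error.
    """
    t_span = jnp.asarray(t_span)
    t_eval = jnp.asarray(t_eval)
    # always add both endpoints: [t_span[0], *t_eval, t_span[1]]
    merged = jnp.concatenate([t_span[:1], t_eval, t_span[1:]])
    # traced boolean: cannot be used in a Python `if` under jit
    valid = jnp.all(jnp.diff(merged) >= 0)
    return jnp.where(valid, merged, jnp.nan)


merge_jit = jax.jit(merge_t_args_jax_sketch)
ok = merge_jit(jnp.array([0.0, 1.0]), jnp.array([0.25, 0.5]))
bad = merge_jit(jnp.array([0.0, 1.0]), jnp.array([0.5, 0.25]))  # unsorted t_eval
```

A caller or test can then detect a failed validation by checking the result with jnp.isnan.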
Updated the tests for merge_t_args and trim_t_args to check the new behaviour.
Added tests for merge_t_args_jax and trim_t_args_jax, reusing the non-JAX test cases through inheritance to limit the number of lines.
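The inheritance pattern can be sketched with the standard unittest module. The base class routes every test through one overridable hook, so a subclass re-runs all inherited tests against a different implementation. merge_python, merge_alt, and both test classes are hypothetical. In the actual test suite, the subclass would target the JAX function.

```python
import unittest


def merge_python(t_span, t_eval):
    # hypothetical non-JAX implementation under test
    return [t_span[0], *t_eval, t_span[1]]


def merge_alt(t_span, t_eval):
    # hypothetical second implementation, standing in for the JAX version
    return list((t_span[0],) + tuple(t_eval) + (t_span[1],))


class TestMerge(unittest.TestCase):
    """Test cases written once, against the non-JAX implementation."""

    def merge(self, t_span, t_eval):
        # hook that subclasses override to test another implementation
        return merge_python(t_span, t_eval)

    def test_always_appends_endpoints(self):
        self.assertEqual(list(self.merge([0.0, 1.0], [0.0, 1.0])), [0.0, 0.0, 1.0, 1.0])

    def test_interior_points(self):
        self.assertEqual(list(self.merge([0.0, 2.0], [0.5, 1.5])), [0.0, 0.5, 1.5, 2.0])


class TestMergeAlt(TestMerge):
    """Inherits every test above; only the implementation hook changes."""

    def merge(self, t_span, t_eval):
        return merge_alt(t_span, t_eval)
```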
Updated jax_odeint to use the JAX versions of these functions.
Added a test case for jax_odeint that validates that functions taking both t_span and t_eval as inputs can be compiled and differentiated.
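The toy below roughly illustrates what such a test exercises. It merges t_span and t_eval, takes one RK4 step per interval with jax.lax.scan, and trims the endpoint result. It is then jit-compiled and differentiated with respect to both time arguments. toy_solve, rhs, rk4_step, and loss are hypothetical stand-ins (for dy/dt = -y), not jax_odeint or the actual test case, and one RK4 step per interval is deliberately crude.

```python
import jax
import jax.numpy as jnp


def rhs(t, y):
    # toy right-hand side: dy/dt = -y
    return -y


def rk4_step(y, t, dt):
    # one classical Runge-Kutta step of size dt (dt may be zero)
    k1 = rhs(t, y)
    k2 = rhs(t + dt / 2, y + dt / 2 * k1)
    k3 = rhs(t + dt / 2, y + dt / 2 * k2)
    k4 = rhs(t + dt, y + dt * k3)
    return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def toy_solve(y0, t_span, t_eval):
    """Hypothetical solve: merge times, one RK4 step per interval, then trim."""
    times = jnp.concatenate([t_span[:1], t_eval, t_span[1:]])  # always add endpoints
    y0 = jnp.asarray(y0, dtype=times.dtype)  # keep the scan carry dtype fixed

    def step(y, t_pair):
        t, t_next = t_pair
        y_next = rk4_step(y, t, t_next - t)
        return y_next, y_next

    _, ys = jax.lax.scan(step, y0, (times[:-1], times[1:]))
    # ys[i] is the state at times[i + 1]; drop the final entry (at t_span[1])
    return ys[:-1]


def loss(t_span, t_eval):
    # scalar function of both time arguments, so it can be differentiated
    return jnp.sum(toy_solve(1.0, t_span, t_eval))


t_span = jnp.array([0.0, 1.0])
t_eval = jnp.array([0.25, 0.5, 0.75])
value = jax.jit(loss)(t_span, t_eval)
g_span, g_eval = jax.jit(jax.grad(loss, argnums=(0, 1)))(t_span, t_eval)
```

The final state at t_span[1] is trimmed away, so in this toy the gradient with respect to t_span[1] is zero.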
To avoid the bug for fixed-step solvers:
Updated Solver.solve to perform automatic compilation only when the method is jax_odeint or a diffrax solver.
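The gating logic can be sketched as follows. should_auto_jit, maybe_compile, and AbstractSolverStandIn are hypothetical names, not the Solver.solve internals. The stand-in class takes the place of a diffrax solver base class so the example runs without diffrax installed. In practice, compile_fn would be jax.jit.

```python
from typing import Any, Callable


class AbstractSolverStandIn:
    """Hypothetical stand-in for a diffrax solver base class."""


def should_auto_jit(method: Any) -> bool:
    """Return True only for the two method types whose time handling is jit-safe."""
    if isinstance(method, str):
        # string methods: only "jax_odeint" is auto-compiled
        return method == "jax_odeint"
    # non-string methods: only diffrax-style solver objects are auto-compiled
    return isinstance(method, AbstractSolverStandIn)


def maybe_compile(fn: Callable, method: Any, compile_fn: Callable[[Callable], Callable]) -> Callable:
    """Wrap fn with compile_fn (e.g. jax.jit) only when should_auto_jit(method) holds."""
    return compile_fn(fn) if should_auto_jit(method) else fn
```

Fixed-step methods therefore fall through to the uncompiled path, while users remain free to compile their own functions explicitly.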
Closes #122.