Closed: jamie-cunningham closed this 1 week ago
@keatincf I moved the JAX functionality to its own package and added an environment variable check for enabling JAX. Unfortunately using JAX seems to add significant overhead, resulting in a 10x slowdown vs numpy. I couldn't nail down why during profiling.
I also haven't directly addressed the NaN issue yet; I couldn't find any good guidance in the solver docs for how to deal with invalid solutions. Note that I used the diffrax package for JAX solvers, as the JAX ODE solver appears to have been declared out of scope by the JAX team and is no longer actively developed. Any thoughts are welcome.
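For reference, roughly the shape of the diffrax usage I mean. This is an illustrative sketch with placeholder dynamics, not the code in this PR; the only NaN handling I have so far is checking the output after the fact:

```python
import jax.numpy as jnp
import diffrax

def vector_field(t, y, args):
    # Placeholder dynamics (simple linear decay), not the real ODEDynamics model.
    return -0.5 * y

term = diffrax.ODETerm(vector_field)
solver = diffrax.Tsit5()
saveat = diffrax.SaveAt(ts=jnp.linspace(0.0, 1.0, 11))

sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.array([1.0, 2.0]),
    saveat=saveat,
)

# The docs don't say much about invalid solutions, so for now the best I have
# is a post-hoc check of the solution for NaNs.
has_nan = jnp.any(jnp.isnan(sol.ys))
```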
test_ode is failing because _clip_state_dot_direct was moved out of ODEDynamics.
I spent a little time trying to see if I could find anything that could be causing the slowdown with JAX. I didn't see anything that would make me think JAX was recompiling more than expected. Removing the jit decorators causes a further slowdown, but only by a second or two. I was worried that maybe the functions being compiled weren't pure enough to be compiled properly.
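One sanity check I know of for retracing: anything printed inside a jitted function only runs while JAX traces it, so repeated prints across calls point at recompilation. A minimal sketch (made-up function, not one from this PR):

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(state, dt):
    # This print only executes while JAX traces/compiles the function, so
    # seeing it repeatedly across calls means the function is being retraced.
    print("tracing step with", state.shape, state.dtype)
    return state + dt * jnp.sin(state)

x = jnp.zeros(3)
step(x, 0.1)              # prints once (first compilation)
step(x, 0.2)              # no print if the cached compilation is reused
step(jnp.zeros(4), 0.1)   # prints again: a new shape triggers recompilation
```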
I also tried using a GPU-enabled JAX, but that produced an even larger slowdown. I'm not sure if that was partially caused by a driver/CUDA version mismatch (Ubuntu's repos have CUDA 12.4 while JAX is using CUDA 12.6).
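For what it's worth, a quick way to confirm which backend JAX actually picked up (if the CUDA install is mismatched, it tends to fall back to CPU):

```python
import jax

# Quick diagnostic: confirm whether JAX is actually running on the GPU.
print(jax.default_backend())   # e.g. "cpu" or "gpu"
print(jax.devices())           # lists the devices JAX can see
```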
It looks like the biggest slowdown might be using diffrax compared to scipy.integrate.
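A rough timing harness along these lines is how I'd compare the two. Placeholder dynamics and a warm-up call so compilation isn't counted; this is not a rigorous benchmark:

```python
import time

import numpy as np
import jax.numpy as jnp
from scipy.integrate import solve_ivp
import diffrax

def f_np(t, y):
    return -0.5 * y

def f_jax(t, y, args):
    return -0.5 * y

y0 = np.ones(4)
term = diffrax.ODETerm(f_jax)
solver = diffrax.Tsit5()

def run_diffrax():
    return diffrax.diffeqsolve(
        term, solver, t0=0.0, t1=1.0, dt0=0.01, y0=jnp.asarray(y0)
    )

# Warm-up so diffrax's compilation cost isn't counted in the timing loop.
run_diffrax().ys.block_until_ready()

start = time.perf_counter()
for _ in range(100):
    solve_ivp(f_np, (0.0, 1.0), y0)
print("scipy  :", time.perf_counter() - start)

start = time.perf_counter()
for _ in range(100):
    run_diffrax().ys.block_until_ready()
print("diffrax:", time.perf_counter() - start)
```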
My guess is that JAX does not provide any major improvement on CPUs compared to numpy/scipy. It could provide better performance on a GPU/TPU, but we may need to adjust the code more to properly take advantage of JAX.
It should be ok to keep JAX in since you've noted that the use of JAX is experimental.
Yeah, I also tried just using scipy.integrate in a JIT wrapper, but there were numpy vs. JAX issues that led me to diffrax. JAX doesn't seem to handle third-party packages that use numpy very well.
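The only workaround I'm aware of for calling numpy/scipy code from inside a jitted region is jax.pure_callback, and it comes with enough caveats (output shapes declared up front, no autodiff through the callback, a host sync on every call) that a native-JAX solver seemed cleaner. Rough sketch, not what's in this PR:

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.integrate import solve_ivp

def _scipy_solve(y0):
    # Runs outside of JAX tracing, on concrete numpy arrays.
    sol = solve_ivp(lambda t, y: -0.5 * y, (0.0, 1.0), np.asarray(y0))
    return sol.y[:, -1].astype(np.float32)

@jax.jit
def step(y0):
    # pure_callback needs the result shape/dtype declared, and the callback is
    # opaque to JAX (no autodiff, host round-trip on every call).
    return jax.pure_callback(
        _scipy_solve, jax.ShapeDtypeStruct(y0.shape, jnp.float32), y0
    )

print(step(jnp.ones(3, dtype=jnp.float32)))
```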
JAX flags have been moved to the class level. Each class object can now decide whether or not to use JAX, and JAX will only be used if it is available in the environment.
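Roughly the shape of what I mean, with illustrative names rather than the actual classes in this PR:

```python
import importlib.util
import numpy as np

JAX_AVAILABLE = importlib.util.find_spec("jax") is not None

class Dynamics:
    """Illustrative base class: each instance opts into JAX explicitly."""

    def __init__(self, use_jax: bool = False):
        # JAX is only used if the caller asks for it AND it is importable.
        self.use_jax = use_jax and JAX_AVAILABLE

    def compute_state_dot(self, state):
        if self.use_jax:
            import jax.numpy as jnp
            return self._state_dot_jax(jnp.asarray(state))
        return self._state_dot_np(np.asarray(state))

    def _state_dot_np(self, state):
        return -0.5 * state  # placeholder dynamics

    def _state_dot_jax(self, state):
        return -0.5 * state  # placeholder dynamics
```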
casting
Yeah, JAX requires that input arrays to JIT functions be JAX types, hence the casting. Since we aren't globally setting a JAX flag, it doesn't make sense to have JAX arrays in non-JAX-compatible code. When there was a global flag we could just make everything a JAX type internally; now we need to make sure that inputs to JIT functions are JAX types.
I can go back in and see if I can beef up the type hinting. Unfortunately, checking for a JAX type requires JAX to be available at runtime. We could make JAX a non-optional dependency again since we are being explicit about its use.
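Sketch of what the casting plus type hints could look like while keeping JAX optional at runtime (names are illustrative; jax.Array is only referenced for type checking):

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Union

import numpy as np

if TYPE_CHECKING:
    # Only evaluated by static type checkers, so JAX stays optional at runtime.
    import jax

ArrayLike = Union[np.ndarray, "jax.Array"]

def call_jit_fn(fn, state: ArrayLike):
    # Assumes JAX is installed when this code path is taken; inputs are cast
    # to JAX arrays before hitting the jitted function.
    import jax.numpy as jnp
    return fn(jnp.asarray(state))
```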
This change moves JAX functions to a separate module and allows toggling use of JAX via an environment variable.
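A minimal sketch of the toggle mechanism described here, with a hypothetical variable name (the PR defines its own):

```python
import os

# Hypothetical environment variable name, for illustration only.
USE_JAX = os.environ.get("USE_JAX", "false").lower() in ("1", "true", "yes")

if USE_JAX:
    import jax.numpy as xp
else:
    import numpy as xp

def state_dot(state):
    # The same code path runs against either backend via the shared array API.
    return -0.5 * xp.asarray(state)
```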