Open kokbent opened 10 months ago
The first thing that jumps out is that you don't appear to be explicitly JIT'ing your computation. Diffrax already does this for you internally for the most part, but even so, best practice is to put an `equinox.filter_jit` on `des`.
The second is that it looks like `beta` and `gamma` might be Python `float`s rather than JAX arrays, in which case I suspect things are recompiling every time. Make them NumPy or JAX arrays. (When using `equinox.filter_jit`, the rule is that things will recompile if a JAX/NumPy array changes shape or dtype, and if anything else changes in any way at all.)
Hi Patrick, thanks for the response. I've jitted my `des` function as you suggested. As for `beta` and `gamma`, making them JAX arrays in the first part of the code doesn't seem to have much effect (they are only used to generate a random sample). Within the sampling model `sir()`, they are handled by numpyro, and I believe all the sampled parameters should already be JAX-traceable arrays. The MCMC is still very slow. I should probably also raise the issue with numpyro.
You can double-check whether recompilation is happening with `equinox.debug.assert_max_traces`, by the way.
So as the title says, I've been trying to fit my SIR ODE model using NUTS on GPU. However, the fit was extremely slow compared to CPU. I'm using `jax` and `numpyro` to do the fitting. I ran this on Google Colab.

This is not an issue specific to `diffrax`; I had the same problem using `odeint` as my ODE solver too. I've searched through the internet, and it seems a similar issue (but with `odeint`) was reported in JAX: Gradients with odeint slow on GPU #5006. According to one of the replies, it seems the tight loop structure in `odeint` is not XLA GPU friendly. Given that I have seen a similar issue when using `diffeqsolve`, I guess that it also uses a similar technique and suffers from the same problem? The question then is: is there any way to circumvent the problem within the `diffrax` package, perhaps with another type of implementation?

Here's the code I use: