fehiepsi closed this issue 4 years ago.
A simpler repro:
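import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def dz_dt(z, t):
    # dz/dt = z, written by indexing the two components explicitly
    return jnp.stack([z[0], z[1]])

def f(z):
    y = odeint(dz_dt, z, jnp.arange(10.))
    return jnp.sum(y)

jax.grad(f)(jnp.ones(2))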
It seems to me that the indices 0 and 1 cause the issue.
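The TypeError in this thread is JAX refusing to take a reverse-mode derivative with respect to an integer-valued input. The sketch below reproduces the same class of error with plain `jax.grad`, independent of odeint; the function `square` is only an illustration, and the exact error message text differs across JAX versions.

```python
import jax
import jax.numpy as jnp

def square(x):
    return x * x

# Differentiating with respect to a float input works as expected.
float_grad = jax.grad(square)(jnp.float32(3.0))

# Differentiating with respect to an int32 input raises a TypeError.
# Presumably odeint hit the same check when integer values (such as
# the indices 0 and 1) ended up among the primal inputs.
try:
    jax.grad(square)(jnp.int32(3))
    int_grad_error = None
except TypeError as err:
    int_grad_error = err
```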
Ah, this is indeed because of #3562. Thanks for catching it!
Unfortunately I've got to go afk for a while, but I should be able to fix this tonight (if no one beats me to it).
As a temporary workaround, you can use this version:
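import jax.numpy as jnp
from jax.experimental.ode import _odeint_wrapper

# Calls the private helper directly, bypassing the public odeint entry point (private API, may change).
def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
    return _odeint_wrapper(func, rtol, atol, mxstep, y0, t, *args)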
Thanks, @mattjj!
I didn't get to it last night, but #3587 should fix this. Thanks for catching it.
I'll do another PyPI release after the fix goes in.
Just pushed jax==0.1.72 to PyPI.
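Assuming a JAX release that includes the fix (0.1.72 or later; not checked against every version), the simpler repro should run with the public `odeint`. This sketch also compares the result with the analytic gradient: since dz/dt = z, z(t) = z0 * exp(t), so the derivative of the summed trajectory with respect to each component of z0 is the sum of exp(t) over the time points 0 through 9. The names `z0_grad` and `expected` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def dz_dt(z, t):
    # Same dynamics as the simpler repro: dz/dt = z, built by indexing.
    return jnp.stack([z[0], z[1]])

def f(z):
    y = odeint(dz_dt, z, jnp.arange(10.))
    return jnp.sum(y)

z0_grad = jax.grad(f)(jnp.ones(2))

# Analytic check: d f / d z0_i = sum over t = 0..9 of exp(t).
expected = jnp.sum(jnp.exp(jnp.arange(10.)))
```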
Here is a repro script, which works with the previous version:
Running the above script raises the error:
TypeError: Primal inputs to reverse-mode differentiation must be of float or complex type, got type int32
I tried to trace the error but found no hint of where the int variables are created. I think the issue started with https://github.com/google/jax/pull/3562.
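One way to look for where integer values appear is to trace the function with `jax.make_jaxpr` and scan the resulting jaxpr for integer-typed values. The helper `integer_values_in_jaxpr` below is hypothetical (not a JAX API) and is only a sketch: it inspects the top-level jaxpr and does not recurse into sub-jaxprs (for example inside `jit` or control-flow primitives), and what it reports for indexing such as `z[0]` depends on how the installed JAX version lowers it.

```python
import jax
import jax.numpy as jnp

def integer_values_in_jaxpr(fun, *example_args):
    """Hypothetical debugging helper: trace `fun` and report integer-typed
    constants, inputs, and equation operands/results in the top-level jaxpr."""
    closed = jax.make_jaxpr(fun)(*example_args)
    jaxpr = closed.jaxpr
    found = []

    def check(var, where):
        # Tokens and similar abstract values may have no dtype; skip them.
        dtype = getattr(var.aval, "dtype", None)
        if dtype is not None and jnp.issubdtype(dtype, jnp.integer):
            found.append((where, str(var.aval)))

    for v in jaxpr.constvars:
        check(v, "const")
    for v in jaxpr.invars:
        check(v, "input")
    for eqn in jaxpr.eqns:
        # Label operands/results by the primitive that uses or produces them.
        for v in list(eqn.invars) + list(eqn.outvars):
            check(v, str(eqn.primitive))
    return found
```

For example, `integer_values_in_jaxpr(lambda z, i: z[i], jnp.ones(2), 0)` reports the traced integer index among the inputs, while a purely floating-point function such as `lambda z: z * 2.0` yields an empty list.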