jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

ode is not working in jax 0.1.70 #3584

Closed fehiepsi closed 4 years ago

fehiepsi commented 4 years ago

Here is a repro script, which worked in previous versions:

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def dz_dt(z, t, theta):
    """ Lotka–Volterra equations. """
    u = z[0]
    v = z[1]
    alpha, beta, gamma, delta = theta[0], theta[1], theta[2], theta[3]
    du_dt = (alpha - beta * v) * u
    dv_dt = (-gamma + delta * u) * v
    return jnp.stack([du_dt, dv_dt])

def f(z):
    y = odeint(dz_dt, z, jnp.arange(10.), jnp.ones(4))
    return jnp.sum(y)

jax.grad(f)(jnp.ones(2))

Running the above script raises TypeError: Primal inputs to reverse-mode differentiation must be of float or complex type, got type int32. I tried to trace the error but found no hint as to where the int variables are created. I think the issue was introduced by https://github.com/google/jax/pull/3562.
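For context, the dtype requirement behind this error can be seen without odeint. The snippet below is a minimal sketch, not a reproduction of the odeint internals: `square_sum` is a hypothetical function chosen for illustration, and the exact wording of the error message differs between JAX versions, so only the exception type is checked.

```python
import jax
import jax.numpy as jnp


def square_sum(x):
    # Hypothetical scalar-valued function, used only for illustration.
    return jnp.sum(x ** 2)


# Float inputs: reverse-mode differentiation works as usual.
float_grad = jax.grad(square_sum)(jnp.ones(2))

# Integer inputs: jax.grad rejects them with a TypeError.
try:
    jax.grad(square_sum)(jnp.arange(2))
    int_error = None
except TypeError as err:
    int_error = err

print(float_grad)
print(type(int_error).__name__)
```

In the repro above every user-supplied array is a float, which is why the int32 in the message is surprising: the integer value must be introduced somewhere inside the library call rather than by the caller.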

fehiepsi commented 4 years ago

A simpler repro:

def dz_dt(z, t):
    return jnp.stack([z[0], z[1]])

def f(z):
    y = odeint(dz_dt, z, jnp.arange(10.))
    return jnp.sum(y)

jax.grad(f)(jnp.ones(2))

It seems to me that the indices 0, 1 cause the issue.

mattjj commented 4 years ago

Ah, this is indeed because of #3562. Thanks for catching it!

mattjj commented 4 years ago

Unfortunately I've got to go afk for a while, but I should be able to fix this tonight (if no one beats me to it).

mattjj commented 4 years ago

As a temporary workaround, you can use this version:

import jax.numpy as jnp
from jax.experimental.ode import _odeint_wrapper

# Temporary workaround: redefine odeint to call the private wrapper directly.
def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
    return _odeint_wrapper(func, rtol, atol, mxstep, y0, t, *args)

fehiepsi commented 4 years ago

Thanks, @mattjj!

mattjj commented 4 years ago

I didn't get to it last night, but #3587 should fix this. Thanks for catching it.

I'll do another pypi release after the fix goes in.

mattjj commented 4 years ago

Just pushed jax==0.1.72 to pypi.
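To check whether an environment already has this release, a small version comparison is enough. This is a rough sketch that assumes 0.1.72 is the first release containing the fix, as the thread implies; `parse_version` and `released_with_fix` are hypothetical helper names, and the parser only handles plain X.Y.Z strings (no dev or rc suffixes).

```python
FIXED_IN = (0, 1, 72)  # assumption: first release with the fix, per the thread


def parse_version(version):
    """Turn a plain 'X.Y.Z' string into a tuple of ints for comparison."""
    return tuple(int(part) for part in version.split("."))


def released_with_fix(version):
    """Return True if `version` is 0.1.72 or later."""
    return parse_version(version) >= FIXED_IN


print(released_with_fix("0.1.70"))
print(released_with_fix("0.1.72"))
```

In practice the argument would be jax.__version__. A False result does not by itself mean the installation is affected, since versions before 0.1.70 were reported to work.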