jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

grad of vmap of odeint with rng-dependent dynamics gives tracer error #2797

Closed · duvenaud closed this issue 4 years ago

duvenaud commented 4 years ago

I'm trying to code up FFJORD in a Jax-y style. This means I need to compute the gradient of the batched training loss, which is itself defined by a call to odeint, whose dynamics depend on a random number. Somehow, all of these together cause a tracer-level error. I updated jax and jaxlib to 0.1.64.

If I remove any one of vmap, grad, odeint, or the rng, it works.

I've pared it down to a minimal example (edited to be even more minimal):

from jax.api import grad, vmap
from jax import random
from jax.experimental.ode import odeint
import jax.numpy as np

def ffjord_log_density(params, x, D, rng):

    eps = random.normal(rng, (D,))

    def aug_dynamics(aug_state, t, args):
        dlogp_dt = 1.0
        dz_dt = eps
        return np.hstack([dz_dt, dlogp_dt])

    init_state = np.hstack([x, 0.])
    aug_out = odeint(aug_dynamics, init_state, np.array([0., 1.]), params)[1]
    return aug_out[-1]

def batch_likelihood(params, data, rng):
    N, D = np.shape(data)
    rngs = random.split(rng, N)
    batch_density = vmap(ffjord_log_density, in_axes=(None, 0, None, 0))
    return np.mean(batch_density(params, data, D, rngs))

if __name__ == "__main__":

    rng = random.PRNGKey(0)
    data = random.normal(rng, (10, 2))
    D = data.shape[1]

    def objective(params):
        return -batch_likelihood(params, data, rng)

    init_params = np.array([1.0])

    print("Evaluating objective...")
    print(objective(init_params))  # Works

    print("Evaluating gradient of objective...")
    print(grad(objective)(init_params))  # Fails

Here's the error:

jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function. The functions being transformed should not save traced values to global state. Details: Tracer from a higher level: Traced<ShapedArray(float32[3])>with<BatchTrace(level=2/1)> with val = DeviceArray([[ 0.15854332, 1.0789754 , 1. ], [-0.52583337, 0.13718702, 1. ], [ 0.9702075 , -0.12123355, 1. ], [ 0.77035123, 0.77410245, 1. ], [ 2.5521684 , 1.0251945 , 1. ], [-0.42066336, -0.7609343 , 1. ], [-0.16220848, 0.12555611, 1. ], [ 0.12595005, -0.44731647, 1. ], [-1.2705209 , -0.13543864, 1. ], [ 1.8663106 , 0.56866485, 1. ]], dtype=float32) batch_dim = 0 in trace JaxprTrace(level=2/1).

mattjj commented 4 years ago

I think the problem is that aug_dynamics closes over the vmap tracer on eps. Can we promote that to be an explicit argument instead of having the dynamics function close over it?

(I'd like to promote odeint to handle all kinds of closures in its function-valued arguments, like the lax control flow primitives, but right now it can't...)

mattjj commented 4 years ago

This seems not to crash:

def ffjord_log_density(params, x, D, rng):

    eps = random.normal(rng, (D,))

    def aug_dynamics(aug_state, t, eps, args):
        dlogp_dt = 1.0
        dz_dt = eps
        return np.hstack([dz_dt, dlogp_dt])

    init_state = np.hstack([x, 0.])
    aug_out = odeint(aug_dynamics, init_state, np.array([0., 1.]), eps, params)[1]
    return aug_out[-1]

Bad error message though...

shoyer commented 4 years ago

> (I'd like to promote odeint to handle all kinds of closures in its function-valued arguments, like the lax control flow primitives, but right now it can't...)

Couldn't we closure convert via tracing to a JAXpr, like in custom_root? It seems like that shouldn't be too hard...

mattjj commented 4 years ago

Yes, that will work, and is basically what we have to do. But it won't scale to more primitives, and it'd be a huge burden for folks like @duvenaud to learn about just to add things like odeint VJPs. We should factor this into utility functions so a user can just designate which arguments are function-valued, and we handle the jaxpr tracing stuff behind the scenes.
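
A minimal sketch of what such a utility could look like, assuming jax.closure_convert (which exposes exactly this trace-to-jaxpr step in later JAX versions). The wrapper name odeint_cc is made up for this illustration and is not part of the JAX API:

import jax
from jax.experimental.ode import odeint

def odeint_cc(dynamics, y0, ts, *args):
    # Trace `dynamics` to a jaxpr; anything it closes over (e.g. a vmapped
    # `eps`) is hoisted out and returned in `consts`.
    converted, consts = jax.closure_convert(dynamics, y0, ts[0], *args)
    # odeint then calls converted(y, t, *args, *consts), so the function it
    # sees no longer closes over traced values.
    return odeint(converted, y0, ts, *args, *consts)

With a wrapper like this, the original aug_dynamics could keep closing over eps, e.g. odeint_cc(aug_dynamics, init_state, np.array([0., 1.]), params). I believe later versions of jax.experimental.ode.odeint perform this closure conversion internally.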

duvenaud commented 4 years ago

D'oh, thanks for the fix. In my defense, I had also tried moving eps = random.normal(rng, (D,)) inside the dynamics, but now I realize that didn't work because the dynamics then closed over rng instead. There was also no gradient flowing through it, so I thought it would be fine to ignore that dependency.

EyalRozenberg1 commented 2 years ago

Hey David @duvenaud. Have you implemented a JAX version of FFJORD?

duvenaud commented 2 years ago

Yes, I made a toy demo, but @jacobjinkelly made a nice fleshed-out version: https://github.com/jacobjinkelly/easy-neural-ode/blob/master/ffjord_mnist.py

EyalRozenberg1 commented 2 years ago

Wonderful! Thanks @duvenaud @jacobjinkelly