google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

Add more documentation on calling JAX transforms inside a Haiku module. #176

Open LenaMartens opened 3 years ago

LenaMartens commented 3 years ago

We have https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html, but I think we can add more context on the error you get (an UnexpectedTracerError) and on alternatives to hk.lift (e.g., if the only side-effect is parameter creation, switching on hk.running_init()).
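
For concreteness, a minimal sketch of that hk.running_init() alternative (the module, shapes and names here are just illustrative, and it assumes parameter creation really is the only side-effect):

import jax
import jax.numpy as jnp
import haiku as hk

def f(x):  # x: [batch, features]
    net = hk.Linear(16)

    if hk.running_init():
        # During init, call the module directly (no JAX transform) so that
        # its parameters are created without any tracers leaking out.
        return net(x)
    else:
        # During apply, the parameters already exist (hk.get_parameter only
        # reads state), so wrapping the module in a JAX transform is safe.
        return jax.vmap(net)(x)

forward = hk.transform(f)
rng = jax.random.PRNGKey(0)
x = jnp.ones([4, 8])
params = forward.init(rng, x)
out = forward.apply(params, rng, x)  # shape [4, 16]

The two branches compute the same result; the point is simply that no Haiku state is created inside the jax.vmap trace.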

jatentaki commented 3 years ago

Is it possible to capture the leaked tracer exceptions arising inside hk.transform and wrap them with an explanation that this could be due to using JAX transforms? Even if that's not always the real issue, it would point the user to documentation that may help, rather than leaving them to debug the JAX internals themselves.
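
Something along these lines is roughly what I mean; jax.errors.UnexpectedTracerError is the real JAX exception, but the wrapper itself is a hypothetical user-side sketch, not existing Haiku behaviour:

import jax
import jax.numpy as jnp
import haiku as hk

def with_haiku_hint(fn):
    # Hypothetical wrapper: re-raise leaked-tracer errors with a pointer to
    # the Haiku docs on mixing JAX transforms and modules.
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except jax.errors.UnexpectedTracerError as e:
            raise RuntimeError(
                "A tracer leaked out of a Haiku-transformed function. This is "
                "often caused by using a JAX transform inside a Haiku module; "
                "see https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html"
            ) from e
    return wrapper

forward = hk.transform(lambda x: hk.Linear(4)(x))
init = with_haiku_hint(forward.init)
params = init(jax.random.PRNGKey(0), jnp.ones([3]))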

jatentaki commented 3 years ago

(reposting a comment I originally made in the jax repo)

@LenaMartens sorry, I found that the issue persists. First, my code above was a bit buggy (wrong assertions, etc.), so I modified it slightly and applied your fix (below). Additionally, I had to select the last output of odeint along the timestep axis, since otherwise it differs from the plain odefunc evaluation (did you modify that in your example?). It runs and seems to produce the right results, which is why I initially said the issue was solved, but running with JAX_CHECK_TRACER_LEAKS=1 still reports a leak. In my actual codebase the leak is reported at this line:

    def dy_and_div(self, ctx, y):
        e = jax.random.normal(hk.next_rng_key(), y.shape)

My suspicion would be that the tracer "redirected" with hk.running_init() fails to notice that an actual execution would ask for many more random keys than the init run. Is this related?

By the way, I'm not sure if I'm following the docs correctly. Shouldn't I have also replaced the jax.vmap call with hk.vmap? Also, I see that there are Haiku versions of grad and value_and_grad. In my code I'm using vjp (inside traced code); is that somehow a different case that doesn't require special treatment, or is it just that Haiku hasn't yet implemented the wrapper my code would need?

import jax
import jax.numpy as jnp
import haiku as hk
from jax.experimental.ode import odeint

def ODEnet():
    def ODEnet_inner(y, context):
        return hk.Linear(y.shape[-1])(y)

    return ODEnet_inner

class ODEfunc(hk.Module):
    def __init__(self, diffeq):
        super(ODEfunc, self).__init__()
        self.diffeq = diffeq

    def __call__(self, states, t):
        assert t.shape == (), t.shape
        y, _logpx = states

        dy, div = self.dy_and_div(t[None], y)
        return dy, -div

    def dy_and_div(self, ctx, y):
        # Hutchinson-style stochastic estimate of the divergence of diffeq.
        e = jax.random.normal(hk.next_rng_key(), y.shape)
        dy, dy_vjp_fn = jax.vjp(lambda y: self.diffeq(y, ctx), y)
        vjp, = dy_vjp_fn(e)
        div = jnp.dot(vjp, e)

        return dy, div

class CNF(hk.Module):
    def __init__(self, odefunc):
        super(CNF, self).__init__()

        self.odefunc = odefunc
        self.T = jnp.array(1.)

    def __call__(self, x, logpx):
        integration_times = jnp.array([0., self.T])

        v_call = jax.vmap(
            lambda x, lpx, t: self._call_inner(x, lpx, t),
            in_axes=(0, 0, None),
            out_axes=0,
        )

        z_t, logpz_t = v_call(x, logpx, integration_times)

        if logpx is None:
            return z_t
        else:
            return z_t, logpz_t

    def _call_inner(self, x, logpx, integration_times):
        assert logpx.shape == (*x.shape[:-1], 1)

        states = (x, logpx)

        if hk.running_init():
            # At init time, call odefunc once directly (outside odeint) so
            # that its parameters get created without leaking tracers.
            state_t = self.odefunc(states, integration_times[0])
            z_t, logpz_t = state_t[:2]
            logpz_t = logpz_t[..., None]
        else:
            state_t = odeint(
                self.odefunc, 
                states,
                integration_times,
            )

            z_t, logpz_t = state_t[:2]
            z_t = z_t[..., -1, :]
            logpz_t = logpz_t[..., -1, :]

        return z_t, logpz_t

import unittest

class Tests(unittest.TestCase):
    def test_cnf_unconditional(self):
        import numpy as np

        N, zdim, context_dim = 128, 32, 48

        def create(*args, **kwargs):
            return CNF(
                ODEfunc(ODEnet()),
            )(*args, **kwargs)

        cnf_fn = hk.transform_with_state(create)
        rng = jax.random.PRNGKey(42)

        x = np.random.randn(N, 3)
        logpx = np.zeros((N, 1))

        data = dict(x=x, logpx=logpx)
        params, state = cnf_fn.init(rng, **data)
        (z_t, logpz_t), state = cnf_fn.apply(params, state, rng, **data)

if __name__ == '__main__':
    unittest.main()

LenaMartens commented 3 years ago

Sorry, I took a shortcut in my previous answer for brevity and omitted some important details. This is partly why better documentation would be great!

The fundamental issue here is that Haiku modules are side-effecting: they modify parameter and RNG state, and that state is neither an explicit function argument nor an explicit return value of the module functions. That is why we have hk.transform, which serves as a functionalization point: all module state becomes explicit arguments to the init/apply functions, and those functions can then be used with JAX transforms. So we have a functional boundary at the hk.transform, and inside that boundary we can't safely use JAX transforms. The leaked tracers originate in these side-effecting operations on the Haiku state.
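
As a minimal illustration of that boundary (shapes and names are just for the example):

import jax
import jax.numpy as jnp
import haiku as hk

def loss_fn(x):
    # Inside the transform, parameter creation and RNG use are side effects
    # on hidden Haiku state rather than explicit inputs and outputs.
    y = hk.Linear(1)(x)
    noise = jax.random.normal(hk.next_rng_key(), y.shape)
    return jnp.mean((y + 1e-3 * noise) ** 2)

loss = hk.transform(loss_fn)

rng = jax.random.PRNGKey(0)
x = jnp.ones([4, 3])
params = loss.init(rng, x)
# Outside the transform the state is explicit, so JAX transforms compose safely:
grads = jax.grad(loss.apply)(params, rng, x)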

My previous solution of using hk.running_init() only works if the only side-effect is the creation of parameters, but you run into trouble again if your module advances the RNG state or changes state in any other way.

We provide some utilities to make it easier to nest JAX transforms inside a hk.transform, like hk.vmap. As you noticed, we don't wrap all functions, partly because it's difficult to choose the correct defaults. For example, for hk.vmap we always broadcast your Haiku state across the vmap (so you get the same parameters/RNG on every batch element, effectively in_axes=None for the Haiku state), but you might want to map over your parameter state instead.
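
A small sketch of what that looks like in practice (note that recent Haiku versions also require an explicit split_rng argument to hk.vmap, so the exact signature may differ from the version you're on):

import jax
import jax.numpy as jnp
import haiku as hk

def f(x):  # x: [batch, features]
    net = hk.Linear(16)
    # The parameters (and, unless split, the RNG) are broadcast across the
    # mapped axis, i.e. every batch element sees the same Haiku state.
    return hk.vmap(net, in_axes=0, out_axes=0, split_rng=False)(x)

forward = hk.transform(f)
params = forward.init(jax.random.PRNGKey(0), jnp.ones([4, 8]))
out = forward.apply(params, None, jnp.ones([4, 8]))  # shape [4, 16]

If you instead need different parameters per batch element, you would fall back to the lift approach below and vmap the pure apply function over a mapped params pytree yourself.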

I think the easiest solution here is to try out hk.experimental.lift, which allows you to nest hk.transforms, effectively giving you a way to functionalize a side-effecting Haiku function inside a Haiku module. You can then use that functionally pure function freely with JAX transforms, e.g. like below (to replace the hk.running_init() solution):

transformed_layer = hk.transform(self.odefunc) # nested transform
# lift registers the inner_params in the outer Haiku module's parameters
inner_params = hk.experimental.lift(transformed_layer.init)(hk.next_rng_key(), states, integration_times) 
# switching argument order around because odeint expects states and times to be the first args
arg_order_switch = lambda y, t, params, rng: transformed_layer.apply(params, rng, y, t) 
state_t = odeint(
    arg_order_switch,             
    states,
    integration_times,
    inner_params, 
    hk.next_rng_key()
)

The only annoyance is that you end up with multiple JAX transforms, and I think you might need to repeat this exercise for every Haiku function you're transforming (i.e. at the points where you're using jax.vjp and jax.vmap; in the vmap case you could use hk.vmap instead).
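
As a rough, untested sketch of doing that exercise for the jax.vjp call in dy_and_div above (it mirrors your argument names and should be treated as a starting point rather than a drop-in fix):

    def dy_and_div(self, ctx, y):
        e = jax.random.normal(hk.next_rng_key(), y.shape)

        # Nested transform around the side-effecting diffeq call.
        diffeq_t = hk.transform(lambda y_: self.diffeq(y_, ctx))
        # lift registers the inner parameters in the outer module's parameters.
        inner_params = hk.experimental.lift(diffeq_t.init)(hk.next_rng_key(), y)

        # diffeq_t.apply is functionally pure, so jax.vjp is safe to use on it.
        pure_diffeq = lambda y_: diffeq_t.apply(inner_params, None, y_)
        dy, dy_vjp_fn = jax.vjp(pure_diffeq, y)
        vjp, = dy_vjp_fn(e)
        div = jnp.dot(vjp, e)
        return dy, div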

Apologies for this sharp edge! Does what I wrote here make sense? Could you try out the lift option?