google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0
2.86k stars 232 forks

hk.switch does not work inside a hk.vmap function when hk.set_state is used #785

Open · rsmath opened this issue 2 months ago

rsmath commented 2 months ago

Hi,

I have been using Haiku (amazing tool, btw!) for about 8 months now. Until now, I wrapped my custom Haiku modules in hk.transform. Inside my module, I vmapped (with hk.vmap) a function containing an hk.switch call that evaluates a chosen branch function.

I recently moved to stateful modules, which require hk.transform_with_state. I also need to track a specific value over time in my model that must not be updated by the optimizer, so I store it with hk.set_state("name", val) and access it later. However, as soon as I call set_state anywhere in the model, the vmapped function fails with the error

ValueError: vmap has mapped output but out_axes is None
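One plausible reading of this error (an assumption, not confirmed in this thread): hk.switch passes Haiku's internal state through jax.lax.switch, and when the branch index is batched, JAX's vmap rule for switch runs every branch and selects among the results, which marks every output as batched, including state that is returned unchanged. hk.vmap keeps parameters and state unmapped (out_axes=None), so batched state would trigger exactly this message. The plain-JAX sketch below reproduces that mechanism, with a hypothetical `carried` value standing in for the Haiku state.

```python
import jax
import jax.numpy as jnp

def step(i, x, carried):
    # `carried` stands in for Haiku state: every branch returns it unchanged.
    branches = [lambda x, c: (x**2, c) for _ in range(10)]
    return jax.lax.switch(i, branches, x, carried)

idx = jnp.array([0, 3, 9], dtype=jnp.int32)
x = jnp.arange(4.0)
carried = jnp.float32(2.0)

# Batched index, with the carried value requested unmapped (out_axes=None).
unmapped = jax.vmap(step, in_axes=(0, None, None), out_axes=(0, None))
try:
    unmapped(idx, x, carried)
    raised = False
except ValueError as err:  # vmap has mapped output but out_axes is None
    raised = True
    print("ValueError:", err)

# Letting the carried output be mapped works; every entry is simply 2.0.
mapped = jax.vmap(step, in_axes=(0, None, None), out_axes=(0, 0))
y, c = mapped(idx, x, carried)
```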

Is there any way to use hk.switch inside a function wrapped with hk.vmap when hk.set_state is used anywhere in the module?

Thank you.

rsmath commented 2 months ago

Code

import haiku as hk
import numpy as np
import jax
import jax.numpy as jnp

def forward(input_):
    # hk.set_state("some_state_val", 2)  # uncommenting this line triggers the ValueError

    # Ten branch functions for hk.switch (all identical here, for simplicity)
    applys = [lambda x: x**2 for _ in range(10)]

    def func(i, x):
        # i has shape (1,) per example; hk.switch needs a scalar index
        return hk.switch(i.squeeze(), applys, x)

    # Map over the index only; the data array is shared by every example
    func_vmap = hk.vmap(func, in_axes=(0, None), split_rng=True)

    pred = func_vmap(input_[1], input_[0])

    return pred

rng_key = jax.random.PRNGKey(4)

stateful_forward = hk.transform_with_state(forward)

data = jnp.asarray(np.random.rand(100, 30))
idx = jnp.asarray(np.random.randint(0, 10, (100, 1))).astype(jnp.int32)

inp = [data, idx]

init, apply = stateful_forward.init, stateful_forward.apply

init = jax.jit(init)
apply = jax.jit(apply)

params, state = init(rng_key, inp)

output, state = apply(params, state, rng_key, inp)