Open · rsmath opened 4 months ago
Code
```python
import haiku as hk
import numpy as np
import jax
import jax.numpy as jnp

def forward(input_):
    # hk.set_state("some_state_val", 2)  # line that causes the problem
    def func(i, x):
        # Evaluate the branch selected by index i on x.
        temp = hk.switch(i.squeeze(), applys, x)
        return temp

    # Map over the indices (axis 0); the same data is passed to every call.
    func_vmap = hk.vmap(func, in_axes=(0, None), split_rng=True)

    # Ten branch functions (identical here, just for the repro).
    applys = []
    for i in range(10):
        temp_apply = lambda x: x**2
        applys.append(temp_apply)

    pred = func_vmap(input_[1], input_[0])
    return pred

rng_key = jax.random.PRNGKey(4)
stateful_forward = hk.transform_with_state(forward)

data = jnp.asarray(np.random.rand(100, 30))
idx = jnp.asarray(np.random.randint(0, 10, (100, 1))).astype(jnp.int32)
inp = [data, idx]

init, apply = stateful_forward.init, stateful_forward.apply
init = jax.jit(init)
apply = jax.jit(apply)
params, state = init(rng_key, inp)
output, state = apply(params, state, rng_key, inp)
```
Hi,

I have been using Haiku (amazing tool, btw!) for about 8 months now. Until now, I wrapped my custom Haiku modules with `hk.transform`. Inside my module, I vmapped a function (using `hk.vmap`) that contained an `hk.switch` statement (to evaluate a chosen branch function).

I recently moved to stateful modules, which need the `hk.transform_with_state` transform. I also need to keep track of a specific value over time in my model that should not be updated by the optimizer. For this, I am using `hk.set_state("name", val)` to store it and access it later. However, as soon as I add any `set_state` call anywhere in the model, the vmapped function fails with an error (uncommenting the `hk.set_state` line in the code above is enough to trigger it).

Is there any way to use `hk.switch` inside an `hk.vmap` function when `hk.set_state` is used in the module?

Thank you.