Open jjyyxx opened 1 year ago
Hi @jjyyxx, this is indeed confusing behaviour. Changing this would be backwards incompatible with all existing usages of hk.vmap, so I think for now we will need to work around it.
hk.vmap is mostly useful when you make use of hk.{g,s}et_state. If you are developing a transformer then you are unlikely to be using these APIs, and I think it would be safe to unconditionally use jax.vmap, which I believe does what you want (it creates a new instance of the module on each call to the mapped function).
If you need to use hk.vmap, there is a way to define a version of it that has the reuse semantics you want: wrap the mapped function in a module. The only caveat is that this adds a prefix to the module names to disambiguate them:
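import haiku as hk
import jax
import jax.numpy as jnp

def vmap_with_reuse(f, *, name: str | None = None):
    # Share the RNG across the batch during init; split it per example afterwards.
    f = hk.vmap(f, split_rng=(not hk.running_init()))
    # Turn the mapped function into a module class so each call gets its own name scope.
    f = hk.to_module(f)
    return lambda *a, **k: f(name=name)(*a, **k)

def f3(x):
    def g(x):
        return hk.Linear(2)(x)
    x = vmap_with_reuse(g)(x)
    x = vmap_with_reuse(g)(x)
    return x

# The input shape (4, 2) is an arbitrary example, not part of the original comment.
w3 = hk.transform(f3).init(jax.random.PRNGKey(0), jnp.ones((4, 2)))
# w3: dict_keys(['g/linear', 'g_1/linear'])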
Thanks for your suggestion! Indeed, I found that jax.vmap works fine before filing this issue. But I was worried by the documentation, which says hk.vmap is "Equivalent to jax.vmap() with module parameters/state not mapped." From my perspective, this implies that hk.vmap handles both parameters and state, so I kept using hk.vmap at that time.
However, you mentioned that "hk.vmap is mostly useful when you make use of hk.{g,s}et_state".
So, if only hk.get_parameter is used, is there no need to use hk.vmap? Also, what about the behavior of hk.next_rng_key inside jax.vmap?
I have to admit that I do not fully understand the necessity of hk.vmap instead of jax.vmap. Nevertheless, when I need to vmap something, I use hk.vmap whenever the inner function contains calls to Haiku modules. This worked fine until I debugged the poor performance of a transformer model. Things boil down to the following snippet.

It turns out that when g is vmapped, modules created inside g reuse a previously created module. In some cases, errors happen immediately due to incompatible shapes, but in other cases (for me, transformer layers have quite consistent shapes), things go wrong silently.

My question: Is this behavior intended? Could the documentation be improved on this topic? Or am I missing something?