module.params_dict() can behave in surprising ways:
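    import haiku as hk
    import jax
    import numpy as np

    def f(x):
        mod = hk.Linear(8)
        print(mod.params_dict())  # empty during init, full during apply
        sequential = hk.Sequential([mod])
        print(sequential.params_dict())  # always empty
        out = sequential(x)
        print(sequential.params_dict())  # no longer empty
        return out

    # without_apply_rng lets apply() be called without an rng argument,
    # which current Haiku versions otherwise require.
    net = hk.without_apply_rng(hk.transform(f))
    p = net.init(jax.random.PRNGKey(428), np.zeros((2, 3)))
    net.apply(p, np.zeros((2, 3)))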
Prints (the first three lines come from init, the last three from apply):

    {}
    {}
    {...}
    {...}
    {}
    {...}
We should clean up & clearly define the desired semantics of params_dict().