google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

Improve `params_dict()` support #23

Status: Open. trevorcai opened this issue 4 years ago.

trevorcai commented 4 years ago

`module.params_dict()` can behave in surprising ways; its result depends on when it is called:

```python
import haiku as hk
import jax
import numpy as np

def f(x):
  mod = hk.Linear(8)
  print(mod.params_dict())  # empty during init, full during apply
  sequential = hk.Sequential([mod])
  print(sequential.params_dict())  # always empty
  out = sequential(x)
  print(sequential.params_dict())  # no longer empty
  return out

# without_apply_rng makes apply() take (params, *args) with no rng argument.
net = hk.without_apply_rng(hk.transform(f))
p = net.init(jax.random.PRNGKey(428), np.zeros((2, 3)))
net.apply(p, np.zeros((2, 3)))
```

Prints (the first three lines come from `init`, the last three from `apply`):

```
{}
{}
{...}
{...}
{}
{...}
```
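The printed pattern is consistent with `params_dict()` returning the current trace's parameters filtered by the module's own name plus the names of submodules it has already called in that trace. The sketch below is a hypothetical, heavily simplified model of that assumption: `Frame`, `Module`, `Linear`, and `Sequential` here are plain-Python stand-ins, not Haiku classes, and the parameter values are placeholders. It reproduces the same empty/non-empty sequence without importing Haiku.

```python
# Hypothetical, simplified model of the behavior reported above.
# None of these classes are Haiku's; they only mimic the assumed rule:
# params_dict() = params owned by this module or by submodules it has
# already called in the current trace.


class Frame:
    """Parameters visible during one trace (one init or apply call)."""

    def __init__(self, params):
        self.params = dict(params)  # flat {"module/param": value}


class Module:
    def __init__(self, frame, name):
        self.frame = frame
        self.name = name
        self.called_submodules = set()  # filled in only as submodules run

    def params_dict(self):
        owners = {self.name} | self.called_submodules
        return {k: v for k, v in self.frame.params.items()
                if k.split("/")[0] in owners}


class Linear(Module):
    def __call__(self, x):
        # Mimics init: parameters are created the first time they are used.
        self.frame.params.setdefault(f"{self.name}/w", "w")
        self.frame.params.setdefault(f"{self.name}/b", "b")
        return x


class Sequential(Module):
    def __init__(self, frame, layers, name="sequential"):
        super().__init__(frame, name)
        self.layers = layers

    def __call__(self, x):
        for layer in self.layers:
            self.called_submodules.add(layer.name)  # tracked only once called
            x = layer(x)
        return x


def f(frame, x, log):
    # Modules are rebuilt on every call, so call tracking starts empty.
    mod = Linear(frame, "linear")
    log.append(mod.params_dict())
    sequential = Sequential(frame, [mod])
    log.append(sequential.params_dict())
    out = sequential(x)
    log.append(sequential.params_dict())
    return out


log = []
init_frame = Frame({})                   # "init": no parameters yet
f(init_frame, 0.0, log)
f(Frame(init_frame.params), 0.0, log)    # "apply": all parameters present
print([bool(d) for d in log])
# [False, False, True, True, False, True]
```

In this model `sequential` only "owns" `linear` after calling it, and that knowledge lives on a module instance that is rebuilt on every call of `f`, which would explain why the middle print in `apply` is empty again.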

We should clean up and clearly define the desired semantics of `params_dict()`.

girving commented 4 years ago

One specific desideratum: any call to `params_dict()` should either raise an exception or return the same result regardless of when it is called.
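One way to meet that rule, sketched below as a hypothetical free function rather than a proposed Haiku API: raise during `init`, where parameters may not exist yet, and during `apply` select parameters purely by name scope (the module itself plus anything nested under `name/`), never by which submodules have been called. The nested `{module_name: {param_name: value}}` input mirrors the layout of `hk.Params`; the flat `module/param` output keys, the function name `scoped_params_dict`, and the `ParamsDictDuringInitError` class are assumptions for illustration.

```python
from typing import Any, Dict, Mapping

# Nested layout {module_name: {param_name: value}}, like hk.Params.
Params = Mapping[str, Mapping[str, Any]]


class ParamsDictDuringInitError(RuntimeError):
    """Hypothetical error: params_dict() called while params are still being created."""


def scoped_params_dict(module_name: str, params: Params,
                       initializing: bool) -> Dict[str, Any]:
    """Hypothetical params_dict() meeting the "raise or always the same" rule."""
    if initializing:
        # During init some parameters may not exist yet, so any answer could
        # change later in the same trace; refuse instead.
        raise ParamsDictDuringInitError(
            f"params_dict() for {module_name!r} is undefined during init")
    # During apply, select purely by name scope (the module itself and any
    # module nested under it), never by which submodules have been called.
    flat = {}
    for mod_name, mod_params in params.items():
        if mod_name == module_name or mod_name.startswith(module_name + "/"):
            for param_name, value in mod_params.items():
                flat[f"{mod_name}/{param_name}"] = value
    return flat


params = {"linear": {"w": 1.0, "b": 0.0}}
print(scoped_params_dict("linear", params, initializing=False))
# {'linear/w': 1.0, 'linear/b': 0.0}
print(scoped_params_dict("sequential", params, initializing=False))
# {}
try:
    scoped_params_dict("linear", {}, initializing=True)
except ParamsDictDuringInitError as err:
    print("raised:", err)
```

Under this rule `sequential.params_dict()` in the example would be `{}` at every point in `apply`, because `mod` was constructed outside `sequential` and so is not in its name scope. That is consistent, though it may not match what a user expects from a container; tracking call history instead restores the container view at the cost of the time dependence the issue describes.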