google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

Support for custom pytrees #32

Open awav opened 4 years ago

awav commented 4 years ago

Hello, Haiku team! Thanks a lot for making Haiku; it's awesome.

I'm interested in sequential probabilistic models. Parameters of probabilistic models are normally constrained; a simple example is a variance, which can only be positive. I gave an example and explanation of constrained parameters in https://github.com/deepmind/dm-haiku/issues/16#issuecomment-602087358. Pytrees fit this use case ideally: users can create their own differentiable "vectors", and I would expect Haiku to support these custom structures out of the box. This would let users get the actual structures back from transformed functions for printing, debugging, and plotting (and the list could be extended with other academic needs). Unfortunately, custom differentiable structures don't work at the moment.

Failing example

```python
In [58]: class S(hk.Module):
    ...:     def __init__(self, x, y):
    ...:         super().__init__()
    ...:         # These are parameters:
    ...:         self.x = x
    ...:         self.y = y
    ...:     def __repr__(self):
    ...:         return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)
    ...: def S_flatten(v):
    ...:     children = (v.x, v.y)
    ...:     aux_data = None
    ...:     return (children, aux_data)
    ...: def S_unflatten(aux_data, children):
    ...:     return S(*children)
    ...: register_pytree_node(S, S_flatten, S_unflatten)
    ...:
    ...:
    ...: def function(s):
    ...:     return np.sqrt(s.x**2 * s.y**2)
    ...:
    ...: def loss(x):
    ...:     s = S(1.0, 2.0)
    ...:     a = hk.get_parameter("free_parameter", shape=[], dtype=jnp.float32, init=jnp.zeros)
    ...:     return jnp.sum([function(s) * a * x])
    ...:
    ...: x = jnp.array([2.0])
    ...: forward = hk.transform(loss)
    ...: key = jax.random.PRNGKey(42)
    ...: params = forward.init(key, x)

In [59]: params
Out[59]:
frozendict({
  '~': frozendict({'free_parameter': DeviceArray(0., dtype=float32)}),
})
```

Thanks

tomhennigan commented 4 years ago

Hi @awav, thanks for trying Haiku!

I think there are two Haiku assumptions that you are challenging here:

  1. Parameters are always jnp.ndarray instances (we assume this when checking the return value of get_parameter but otherwise this is not a hard requirement).
  2. Modules are always temporary objects that are deleted when transformed functions return. By returning a module from a function you violate this, although calling repr or reading non-computed properties should still work. This is a fairly hard requirement: we work hard in Haiku to make sure that when you use transform the result is pure (w.r.t. Haiku API calls), and this would not be the case if modules existed outside transform (there would then need to be a global scope in which they could find parameters/state).

I think we can make this work. Concretely, I would suggest:

  1. Use NamedTuple to define S (no need to register it as a custom pytree then 😄).
  2. Separate the data structure from the module (e.g. we have S and SModule).
  3. We use SModule to create S instances and use get_parameter to mark them as parameters.
  4. As a temporary workaround we make S have a shape property so it looks like a parameter (we could relax this in Haiku).

Putting that all together:

```python
import jax
import jax.numpy as jnp
import haiku as hk
from typing import NamedTuple

class S(NamedTuple):
  x: jnp.ndarray
  y: jnp.ndarray

  @property
  def shape(self):
    # Hack to workaround the fact that `get_parameter` checks tensor shapes.
    return ()

class SModule(hk.Module):
  def __init__(self, x, y, name=None):
    super().__init__(name=name)
    self.s = hk.get_parameter("struct", (), None, init=lambda *_: S(x, y))

  def __call__(self, x, a):
    return jnp.sqrt(self.s.x ** 2 * self.s.y ** 2) * x * a

def loss(x):
  s = SModule(1.0, 2.0)
  a = hk.get_parameter("free", shape=(), dtype=jnp.float32, init=jnp.ones)
  y = s(x, a)
  return jnp.sum(y)

loss = hk.transform(loss)

x = jnp.array([2.0])
key = jax.random.PRNGKey(42)
params = loss.init(key, x)
jax.grad(loss.apply)(params, x)
```

Output:

```
frozendict({
  's_module': frozendict({
                'struct': S(x=DeviceArray(2., dtype=float32), y=DeviceArray(1., dtype=float32)),
              }),
  '~': frozendict({'free': DeviceArray(0., dtype=float32)}),
})
```

If this looks good, I'm happy to change get_parameter to support parameters that are trees (e.g. we would only check the shape if the result of get_parameter is an ndarray instance).

WDYT?
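A minimal sketch of the relaxation proposed above, assuming the shape check simply skips any value that is not an array. The helper name `check_parameter_shape` is hypothetical and this is not Haiku's actual implementation; it only illustrates "check the shape if the result is an ndarray, otherwise pass the tree through".

```python
import jax.numpy as jnp
import numpy as np


def check_parameter_shape(name, value, expected_shape):
  """Hypothetical helper: enforce `shape` only when the parameter is an array.

  Arbitrary pytrees (tuples, NamedTuples, registered classes) are returned
  unchecked, which is the relaxation proposed for `get_parameter`.
  """
  if isinstance(value, (jnp.ndarray, np.ndarray)):
    if tuple(value.shape) != tuple(expected_shape):
      raise ValueError(
          f"{name!r} has shape {tuple(value.shape)}, "
          f"expected {tuple(expected_shape)}")
  return value


# An array parameter is still checked against the requested shape...
w = check_parameter_shape("w", jnp.zeros([2, 3]), [2, 3])
# ...while a tuple of arrays (a "parameter tree") is passed through as-is.
s = check_parameter_shape("s", (jnp.ones([]), jnp.zeros([])), ())
```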

awav commented 4 years ago

@tomhennigan, for a very simple case the NamedTuple approach will work. However, the main challenge is implementing transformed (constrained) parameters. A parameter with a constraint would look like this:

```python
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

class Parameter:
  def __init__(self, init_constrained_value: jnp.ndarray, constraint: tfp.bijectors.Bijector):
    # NOTE: Compute gradients w.r.t. this unconstrained value!!!
    self._unconstrained_value = constraint.inverse(init_constrained_value)
    self._constraint = constraint

  # NOTE: convert the value in unconstrained space to the value in constrained space
  def constrained_value(self):
    return self._constraint.forward(self._unconstrained_value)

  def __call__(self):
    return self.constrained_value()

def loss(x):
  p = Parameter(1.0, tfp.bijectors.Exp())
  return jnp.square(p())

def loss_complex(x):
  class ProbModel:
    def __init__(self):
      self.variance = Parameter(1.0, tfp.bijectors.Exp())
    def __call__(self, x):
      pass  # Model body omitted.
  m = ProbModel()
  return m(x)
```

After initialization, a researcher needs information about the bijector that was passed in, for reasons such as monitoring or debugging an algorithm. Does that make sense?
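A dependency-free sketch of such a parameter, assuming the constraint can be identified by a name. The unconstrained value is the only pytree leaf, so JAX differentiates with respect to it, while the constraint name is stored as aux data and stays inspectable after initialization. `ConstrainedParameter` and the `CONSTRAINTS` table are hypothetical, and `jnp.exp`/`jnp.log` stand in for a TFP `Exp` bijector.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

# Hypothetical registry: constraint name -> (forward, inverse).
# Stands in for tfp bijectors so the sketch has no extra dependencies.
CONSTRAINTS = {
    "exp": (jnp.exp, jnp.log),
    "identity": (lambda u: u, lambda c: c),
}


@register_pytree_node_class
class ConstrainedParameter:
  """Unconstrained value (the only pytree leaf) plus a constraint name (aux data)."""

  def __init__(self, unconstrained_value, constraint):
    self.unconstrained_value = unconstrained_value
    self.constraint = constraint

  @classmethod
  def from_constrained(cls, value, constraint):
    # Map the initial constrained value into unconstrained space.
    _, inverse = CONSTRAINTS[constraint]
    return cls(inverse(jnp.asarray(value)), constraint)

  def __call__(self):
    # Map back to constrained space on every use.
    forward, _ = CONSTRAINTS[self.constraint]
    return forward(self.unconstrained_value)

  def tree_flatten(self):
    return (self.unconstrained_value,), self.constraint

  @classmethod
  def tree_unflatten(cls, constraint, children):
    return cls(children[0], constraint)


def loss(variance):
  return jnp.square(variance())


variance = ConstrainedParameter.from_constrained(1.0, "exp")
# The gradient comes back as a ConstrainedParameter with the same constraint,
# taken w.r.t. the unconstrained value: d/du exp(u)^2 = 2 exp(2u) = 2 at u = 0.
grads = jax.grad(loss)(variance)
```

Storing a name rather than the bijector object is a deliberate choice here: aux data becomes part of the pytree structure that `jax.jit` caches on, so a small hashable value is the safer thing to keep there.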

Also, I don't really like the self.s = hk.get_parameter("struct", (), None, init=lambda *_: S(x, y)) line. It looks hacky, and I would prefer a function for getting a structure, e.g. hk.get_structure("name", getter="").

tomhennigan commented 4 years ago

@sharadmv has done a lot of thinking about probabilistic programming in JAX (outside of Haiku) and might have some useful input for us here.

> After initialization, a researcher needs information about the bijector that was passed in, for reasons such as monitoring or debugging an algorithm. Does that make sense?

Absolutely.

> Also, I don't really like the self.s = hk.get_parameter("struct", (), None, init=lambda *_: S(x, y)) line. It looks hacky, and I would prefer a function for getting a structure, e.g. hk.get_structure("name", getter="").

Agreed that it looks ugly, and I like your suggestion. I think we should probably call this get_parameter_tree to make it clear that it is strongly related to get_parameter (e.g. s = hk.get_parameter_tree("s", init=lambda: S(a, b))). I'm happy to add that and will close out this issue with a commit later today.

Is there anything else in Haiku getting in your way for this type of research?

mattwescott commented 4 years ago

@tomhennigan your get_parameter_tree proposal would be useful to me. On a related note, for some parameter transformations it is useful to know the type of the corresponding module. Is this accessible in haiku without adding type information to module names?

tomhennigan commented 4 years ago

Hey @mattwescott and @awav, sorry for the delay implementing this. Before adding it to core I want to think carefully about how it will interact with JAX transforms, especially when those transforms are used inside a Haiku-transformed function (e.g. via hk.jit).

For now you should be able to use this without needing changes in Haiku by adding the following utility function in your code and using it in your modules (it is slightly ugly since it adds a "Box" type around your type, but otherwise this should unblock you):

```python
from typing import Any, NamedTuple

import haiku as hk

class Box(NamedTuple):
  value: Any
  shape = property(fget=lambda _: ())

def get_parameter_tree(name, init):
  return hk.get_parameter(name, [], init=lambda *_: Box(init())).value
```

You can use it like so:

```python
>>> def f():
...   p = get_parameter_tree("w", lambda: (jnp.ones([]), jnp.zeros([])))
...   return p

>>> hk.transform(f, apply_rng=True).init(None)
frozendict({
  '~': frozendict({
         'w': Box(value=(DeviceArray(1., dtype=float32), DeviceArray(0., dtype=float32))),
       }),
})
```
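To see why the Box wrapper does not get in the way of differentiation, here is a plain-JAX check with no Haiku involved (the `params` dict and `f` are made up for illustration). Because Box is a NamedTuple, JAX flattens straight through it to the wrapped leaves; the `shape` property is not a field and exists only to satisfy `get_parameter`.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class Box(NamedTuple):
  value: Any
  # Not a NamedTuple field: only makes the wrapper report a scalar shape.
  shape = property(fget=lambda _: ())


params = {"w": Box((jnp.ones([]), jnp.zeros([])))}


def f(params):
  a, b = params["w"].value
  return 3.0 * a + b


# Gradients keep the same structure, Box wrapper included.
grads = jax.grad(f)(params)
```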

> On a related note, for some parameter transformations it is useful to know the type of the corresponding module. Is this accessible in haiku without adding type information to module names?

It isn't right now. The closest we have is hk.experimental.custom_creator, which allows you to intercept parameter creation. One thing people at DeepMind have been using it for is stashing all the initializers for their parameters:

```python
>>> inits = {}
>>> def creator(next_getter, name, shape, dtype, init):
...   inits[name] = init
...   return next_getter(name, shape, dtype, init)

>>> f = lambda: hk.nets.MLP([300, 100, 10])(jnp.ones([1, 1]))
>>> f = hk.transform(f, apply_rng=True)
>>> with hk.experimental.custom_creator(creator):
...   f.init(jax.random.PRNGKey(42))
>>> inits
{'mlp/~/linear_0/w': <haiku._src.initializers.TruncatedNormal at 0x7f28476df5f8>,
 'mlp/~/linear_0/b': <function jax.numpy.lax_numpy.zeros>,
 'mlp/~/linear_1/w': <haiku._src.initializers.TruncatedNormal at 0x7f283125b048>,
 'mlp/~/linear_1/b': <function jax.numpy.lax_numpy.zeros>,
 'mlp/~/linear_2/w': <haiku._src.initializers.TruncatedNormal at 0x7f28476df358>,
 'mlp/~/linear_2/b': <function jax.numpy.lax_numpy.zeros>}
```

I could imagine extending this custom getter to also pass the module as well as the init function; then you could keep a copy of type(module). Would that be useful for you?

mattwescott commented 4 years ago

@tomhennigan thanks for the examples.

> I could imagine extending this custom getter to also pass the module as well as the init function

This would be great, so much cleaner!

mattwescott commented 4 years ago

@tomhennigan

Would it be impractical to intercept module creation instead? With a mapping from module names to types, I could use tree.flatten_with_path_up_to for straightforward type-dependent transformations of the parameter tree.

Either approach would likely be sufficient for me to adopt Haiku.
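A sketch of the kind of type-dependent transformation described above, assuming Haiku's usual two-level params layout (module name, then parameter name) and a hypothetical `module_types` mapping such as one gathered by a custom creator. Instead of dm-tree's `flatten_with_path_up_to` it uses a plain dict comprehension, which keeps it dependency-free; `decay_linear_weights` is an illustrative name, not a Haiku API.

```python
import jax.numpy as jnp

# Hypothetical inputs: Haiku-style params and a module-name -> type-name mapping.
params = {
    "mlp/~/linear_0": {"w": jnp.ones([4, 3]), "b": jnp.ones([3])},
    "layer_norm": {"scale": jnp.ones([3]), "offset": jnp.zeros([3])},
}
module_types = {"mlp/~/linear_0": "Linear", "layer_norm": "LayerNorm"}


def decay_linear_weights(params, module_types, rate=0.1):
  """Shrinks every parameter of Linear modules; other modules are untouched."""
  return {
      module_name: {
          name: value * (1.0 - rate)
          if module_types[module_name] == "Linear" else value
          for name, value in module_params.items()
      }
      for module_name, module_params in params.items()
  }


new_params = decay_linear_weights(params, module_types)
```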

tomhennigan commented 4 years ago

Support for extracting module info in a creator has landed 😄 Here's an example colab using it to extract all info to a dict outside the function: https://colab.research.google.com/drive/1tt9ifYFsxvSSXaFAz_Oq59Im8QY4S16o

Using it inside a transformed function is documented here: https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.experimental.custom_creator

awav commented 4 years ago

@tomhennigan, I found out that Flax has support for dataclasses, and it covers most of what I needed. I haven't tried it with Haiku, but I believe it should work with Haiku out of the box. I would expect JAX to work with dataclasses implicitly, but it looks like it cannot, at least not without Flax. Do you have plans to do something similar?
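A minimal sketch of the idea behind Flax's `struct.dataclass`, built only from `dataclasses` and `jax.tree_util.register_pytree_node`. The decorator name `pytree_dataclass`, the `"static"` metadata key and the `Gaussian` example are hypothetical; this is not Flax's implementation. Fields marked static go into aux data (metadata that JAX does not differentiate), and all other fields become pytree children.

```python
import dataclasses

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


def pytree_dataclass(cls):
  """Hypothetical decorator: a frozen dataclass registered as a JAX pytree."""
  cls = dataclasses.dataclass(frozen=True)(cls)
  fields = dataclasses.fields(cls)
  data_names = [f.name for f in fields if not f.metadata.get("static", False)]
  static_names = [f.name for f in fields if f.metadata.get("static", False)]

  def flatten(obj):
    children = tuple(getattr(obj, n) for n in data_names)
    aux = tuple(getattr(obj, n) for n in static_names)
    return children, aux

  def unflatten(aux, children):
    kwargs = dict(zip(data_names, children))
    kwargs.update(zip(static_names, aux))
    return cls(**kwargs)

  register_pytree_node(cls, flatten, unflatten)
  return cls


@pytree_dataclass
class Gaussian:
  mean: jnp.ndarray
  log_scale: jnp.ndarray
  name: str = dataclasses.field(default="gaussian", metadata={"static": True})


def nll(dist, x):
  # Negative log-likelihood up to an additive constant.
  scale = jnp.exp(dist.log_scale)
  return 0.5 * ((x - dist.mean) / scale) ** 2 + dist.log_scale


dist = Gaussian(mean=jnp.array(0.0), log_scale=jnp.array(0.0))
# Gradients come back as a Gaussian; the static `name` field is carried along.
grads = jax.grad(nll)(dist, 2.0)
```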

bionicles commented 4 years ago

Whoa, that struct.dataclass is cool, and it would solve the headaches of passing modules to functions and getting "not a valid JAX type" errors.