google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

Allow creating module instances outside hk.transform #16

Open gehring opened 4 years ago

gehring commented 4 years ago

This is as much a question as it is a feature request. What is the reasoning for not allowing a module instance to be created (but not used) outside hk.transform? I took a look at hk.Module and ModuleMetaClass, but I feared my soul would get harvested by the dark forbidden magic involved before I could identify all the API features it permits.

For example, I would have expected this to be possible:

import haiku as hk

linear = hk.Linear(10)  # currently not allowed: module created outside hk.transform

def forward(x):
  return linear(x)

model = hk.transform(forward)

Concretely, I'm curious what would have to be sacrificed (if anything) to support this kind of usage. Is the restriction meant to prevent a module instance from being used in two different functions wrapped by two different hk.transform calls?

I wouldn't be surprised if I were missing some nasty side effect of allowing module creation outside of hk.transform, but, if not, I think it would be more intuitive to allow this kind of usage.
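For contrast, here is a minimal, runnable sketch of the idiom that does work today, with the module constructed inside the function passed to hk.transform:

import jax
import jax.numpy as jnp
import haiku as hk

def forward(x):
  linear = hk.Linear(10)  # constructed inside the transformed function
  return linear(x)

model = hk.transform(forward)
x = jnp.ones([1, 5])
params = model.init(jax.random.PRNGKey(42), x)
out = model.apply(params, None, x)  # rng=None: apply uses no randomness here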

trevorcai commented 4 years ago

The primary reason we don't currently allow this is that hk.Module objects have unique names (within their hk.transform), accessible via self.name or self.module_name. These names route parameters & state into the right place for hk.get_parameter calls, and are given to the module at construction time (in super().__init__(name=name)).
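To make that concrete, here is a small sketch showing how those unique names surface as keys in the params structure (the key names shown follow Haiku's defaults):

import jax
import jax.numpy as jnp
import haiku as hk

def forward(x):
  # The first hk.Linear is named "linear", the second "linear_1".
  h = hk.Linear(5)(x)
  return hk.Linear(3)(h)

model = hk.transform(forward)
params = model.init(jax.random.PRNGKey(0), jnp.ones([1, 4]))
# params is a nested mapping keyed by module name, roughly:
# {'linear': {'w': ..., 'b': ...}, 'linear_1': {'w': ..., 'b': ...}}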

Uniquifying names requires us to track some state about the names that have already been created. We've made an attempt at allowing the construction of modules that don't call hk.get_parameter (or Haiku's other stateful APIs) in their constructors, but we haven't managed to do this without introducing persistent global state.
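To illustrate the global-state problem, here is a hypothetical sketch (not Haiku internals) of what uniquifying names with no enclosing transform would require: a counter that lives for the whole process.

# Hypothetical sketch only, not Haiku internals.
_NAME_COUNTS = {}  # persistent global state: it would outlive every transform

def unique_name(base: str) -> str:
  n = _NAME_COUNTS.get(base, 0)
  _NAME_COUNTS[base] = n + 1
  return base if n == 0 else f"{base}_{n}"

# Inside hk.transform, the counter can instead be scoped to the current
# transform, so naming starts fresh for each transformed function.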

There are other solutions that we could try here! One idea is to late-bind names inside hk.transform, but we haven't prioritized this line of work.
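Very roughly, late binding could look like this from the user's perspective (hypothetical behavior, not current Haiku):

linear = hk.Linear(10)  # hypothetical: construction succeeds, name deferred

def forward(x):
  return linear(x)  # hypothetical: name "linear" bound here, scoped to this transform

model = hk.transform(forward)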

Does that make sense? WDYT?

gehring commented 4 years ago

That all makes sense, thanks for the explanation!

"One idea is to late-bind names inside hk.transform, but we haven't prioritized this line of work."

I think it would be great if that could be implemented without adding much complexity, but I completely agree that it doesn't feel like a priority. The current API is just as powerful without this feature once you get used to it (which, in my personal experience, took about 3 "oopsies" and cost me no more than 5 minutes of refactoring).

gehring commented 4 years ago

I'm not sure if you want to keep this issue open for feature request tracking purposes but, if not, feel free to close it.

trevorcai commented 4 years ago

That's good to hear - that's been my experience as well :) I'll leave this issue open to track this FR.

awav commented 4 years ago

Hello @trevorcai, @tomhennigan. I really like out-of-the-box solutions, but at the moment I'm struggling to extend Haiku. I need constrained parameters, e.g. a variance (strictly positive) for a Gaussian distribution. The parameter can be represented as a composition with a bijector, unconstrained_parameter -> bijector.forward(unconstrained_parameter); in my code it is a property of the module. The parameter dictionary contains only the unconstrained version, but for tracking and model printing we need the constrained values, and there is no way to get them because the model instance is hidden inside the transformed function.

from typing import Optional

import haiku as hk
import jax.numpy as jnp

class Parameter(hk.Module):  # must be an hk.Module for super().__init__(name=...)
  def __init__(self, init_value: float, name: str):
    super().__init__(name="parameter")
    self._name = name
    # Store the log of the value; exp() below maps it back to the positive reals.
    self._init = hk.initializers.Constant(jnp.log(init_value))

  def __call__(self):
    return jnp.exp(hk.get_parameter(f"unconstrained_{self._name}", shape=[], init=self._init))

class Model(hk.Module):
  def __init__(self, init_variance: float, name: Optional[str] = None):
    super().__init__(name=name)
    self._variance = Parameter(init_variance, "variance")

  @property
  def variance(self):
    return self._variance()

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    return jnp.sum(self.variance * x)

As you can see, the variance value in the parameter dictionary will not mean much without knowing which transformation the model uses (it could be exp, softplus, or another positive bijector).
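The closest workaround I see today is to transform a second function that rebuilds the model and returns the constrained value (a sketch, assuming the Model class above):

import jax

def get_variance(x):
  m = Model(0.1)
  m(x)  # run the model once so the underlying unconstrained parameter exists
  return m.variance  # the constrained (positive) value

variance_fn = hk.without_apply_rng(hk.transform(get_variance))
x = jnp.ones([3])
params = variance_fn.init(jax.random.PRNGKey(0), x)
variance = variance_fn.apply(params, x)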

1. One solution could be for the transformed object to also expose the model, via a hypothetical hk.link API:

def forward_fn(x):
  m = Model(0.1)
  hk.link(m)  # hypothetical API: register m so it stays reachable after transform
  return m(x)

forward = hk.transform(forward_fn)
model = forward.linked_objects  # hypothetical: read-only access to the linked modules

2. Another possible (?) solution could be making hk.transform usable on methods of a holder class (or as a context manager), so the module instance stays reachable:

class Holder(hk.ModuleHolder):  # hypothetical base class
  @hk.transform
  def forward(self, x):
    self.model = Model(0.1)  # the instance stays reachable as an attribute
    return self.model(x)

forward = Holder().forward  # hypothetical: the transformed function for this holder

PS: for me, this is a very important issue and a deciding factor in how I'm going to use the library.