google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Linen: Consider raising an error when reading variables from submodule before initialization? #513

Open avital opened 4 years ago

avital commented 4 years ago

With a submodule that uses shape inference (as most of the built-in Linen modules do), this code is fine:

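from flax import linen as nn

class MyModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    conv = nn.Conv(features=3, kernel_size=(3, 3))
    y = conv(x)
    variables = conv.variables  # populated: conv has already seen x's shape
    return y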

But if you instead do:

class MyModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    conv = nn.Conv(features=3, kernel_size=(3, 3))
    variables = conv.variables  # read before conv(x): no params created yet
    y = conv(x)
    return y

Then I believe you get an empty dict, because the parameters of conv are only initialized once the input shape is known.
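To see why the order matters, here is a toy, framework-free model of lazy (shape-inferred) initialization. It is not Flax code: `LazyDense` and its `variables` attribute are hypothetical stand-ins that only mimic the behaviour described above, where parameters are created on the first call because their shape depends on the input.

```python
from typing import Dict, List


class LazyDense:
    """Toy stand-in for a shape-inferring layer (hypothetical, not Flax)."""

    def __init__(self, features: int):
        self.features = features
        # Nothing can be allocated yet: the kernel shape needs the input size.
        self.variables: Dict[str, Dict[str, List[List[float]]]] = {}

    def __call__(self, x: List[float]) -> List[float]:
        if not self.variables:
            # First call: infer the kernel shape (len(x), features) and create it.
            kernel = [[0.1] * self.features for _ in range(len(x))]
            self.variables = {"params": {"kernel": kernel}}
        kernel = self.variables["params"]["kernel"]
        # y[j] = sum_i x[i] * kernel[i][j]
        return [sum(xi * row[j] for xi, row in zip(x, kernel))
                for j in range(self.features)]


layer = LazyDense(features=3)
before = dict(layer.variables)   # read before the first call
y = layer([1.0, 2.0])
after = dict(layer.variables)    # read after the first call
print(before)                    # {}
print(list(after["params"]))     # ['kernel']
```

Reading `variables` before the first call returns an empty dict rather than failing, which is exactly the silent behaviour the issue is about.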

Users may be surprised by the empty dict, so instead we could raise an error when the variables are empty at read time, explaining what is happening and pointing users in the right direction.
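A minimal sketch of this proposal, again outside Flax: a hypothetical `read_variables` helper that raises an explanatory error instead of silently returning an empty dict. The function name, the `UninitializedVariablesError` type and the message wording are all assumptions for illustration.

```python
from typing import Any, Mapping


class UninitializedVariablesError(Exception):
    """Hypothetical error type for reading variables too early."""


def read_variables(module_name: str, variables: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return `variables`, or raise if they are still empty (hypothetical helper)."""
    if not variables:
        raise UninitializedVariablesError(
            f"Variables of submodule '{module_name}' are empty. If it uses shape "
            "inference, its parameters are only created when it is first called; "
            "call the submodule before reading its variables."
        )
    return variables


print(read_variables("conv", {"params": {"kernel": [[0.0]]}}))
try:
    read_variables("conv", {})
except UninitializedVariablesError as err:
    print("error:", err)
```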

avital commented 4 years ago

(It's also worth considering writing a "design note" explaining the design of shape inference in Linen)

freddyaboulton commented 3 years ago

@avital I'm not sure we should always raise an error if the variables are empty when they are read. If the module does not have any variables, the dict will always be empty, and maybe that's what's intended, as in this (maybe contrived) example:

from flax import linen as nn
from jax import random
import jax.numpy as jnp

class MyRelu(nn.Module):

    @nn.compact
    def __call__(self, x):
        return nn.relu(x)

class UsesMyRelu(nn.Module):
    @nn.compact
    def __call__(self, x):
        relu = MyRelu()
        variables = relu.variables
        y = relu(x)
        variables_after = relu.variables
        print(f"Variables Before: {variables}")
        print(f"Variables After: {variables_after}")
        return y

x = jnp.array([[-1, 5],
               [3, -5]])
rng = dict(params=random.PRNGKey(0))

relu_mod = UsesMyRelu()
output, params = relu_mod.init_with_output(rng, x)
# Printed output:
# Variables Before: FrozenDict({})
# Variables After: FrozenDict({})
avital commented 3 years ago

On the one hand, if a user tries to access the variables of MyRelu, that's probably not what they meant to do, so maybe an error makes sense.

On the other hand, maybe some folks are doing this programmatically, e.g. with a Sequential wrapper where some submodules have variables and some don't. You could imagine wanting to, for example, count the total number of parameters.
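The programmatic use case can be sketched without Flax: counting parameters across a Sequential-like list of submodules, some of which (like `MyRelu`) have no variables. The nested-dict layout and the `count_params` helper are assumptions; as long as empty dicts are allowed, parameter-free submodules simply contribute zero.

```python
from typing import Any, List, Mapping


def count_params(tree: Any) -> int:
    """Count leaf scalars in a nested dict/list structure (hypothetical helper)."""
    if isinstance(tree, Mapping):
        return sum(count_params(v) for v in tree.values())
    if isinstance(tree, (list, tuple)):
        return sum(count_params(v) for v in tree)
    return 1  # a leaf scalar


# Variables of three submodules: dense, relu (parameter-free), dense.
submodule_variables: List[Mapping[str, Any]] = [
    {"params": {"kernel": [[0.0, 0.0], [0.0, 0.0]], "bias": [0.0, 0.0]}},
    {},  # MyRelu-style submodule: an empty dict, not an error
    {"params": {"kernel": [[0.0], [0.0]], "bias": [0.0]}},
]

total = sum(count_params(v) for v in submodule_variables)
print(total)  # (4 + 2) + 0 + (2 + 1) = 9
```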

On the (third?) hand, if you're going to use relu (or any function or object without variables), why would you wrap it in a Module to begin with? Functions work fine. (Or, if you want multiple methods, normal objects work fine.)

To summarize, your argument is interesting, but I still don't understand the full use case where it would be a problem.

freddyaboulton commented 3 years ago

Thanks for explaining that! I agree that wrapping functions without variables in a Module doesn't make much sense.

jheek commented 3 years ago

I don't think checking for empty variables dicts as a safety check is super useful. The assertion doesn't guarantee that all variables exist, only that at least one does, so users still need to understand laziness. Also, it only works during init, because during apply self.variables always returns all the parameters.

The downside of making this an error is that any generic code that relies on module.variables would then need a try...except statement to support variable-free modules.
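To make both of these points concrete, here is a small framework-free sketch. `strict_variables` is a hypothetical accessor that raises on empty dicts (the behaviour under discussion): the check passes for a partially initialized dict that has a kernel but no bias, yet it forces generic code to wrap every access in `try...except`.

```python
from typing import Any, Dict, Mapping


def strict_variables(variables: Mapping[str, Any]) -> Mapping[str, Any]:
    """Hypothetical accessor that raises if no variables exist yet."""
    if not variables:
        raise ValueError("variables are empty")
    return variables


# Point 1: a non-empty check cannot tell that 'bias' is still missing.
partial = {"params": {"kernel": [[1.0]]}}
print(sorted(strict_variables(partial)["params"]))  # ['kernel']


# Point 2: generic code must special-case parameter-free modules.
def collect(all_variables: Dict[str, Mapping[str, Any]]) -> Dict[str, Mapping[str, Any]]:
    out: Dict[str, Mapping[str, Any]] = {}
    for name, variables in all_variables.items():
        try:
            out[name] = strict_variables(variables)
        except ValueError:
            out[name] = {}  # variable-free module
    return out


print(collect({"dense": partial, "relu": {}}))
```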

avital commented 3 years ago

> to support variable-free modules.

Why would someone make variable-free modules rather than just use functions?

jheek commented 3 years ago

I've seen it a few times already in bigger projects. Sometimes you configure a module using hyperparameters, and one of the options is a parameter-free one, like using a sine positional embedding vs learned positional embeddings.
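A framework-free sketch of this configuration pattern: a hyperparameter selects either a learned positional embedding, which owns a parameter table, or a fixed sinusoidal one, which owns none. The class names and the `make_position_embedding` factory are hypothetical, and the sinusoid uses the common Transformer-style sin/cos-by-frequency formula, which is an assumption about what "sine positional embedding" refers to here.

```python
import math
from typing import Any, Dict, List


class SinePositionEmbedding:
    """Fixed sinusoidal embedding: nothing to learn, so no variables."""

    def __init__(self, dim: int):
        self.dim = dim
        self.variables: Dict[str, Any] = {}

    def __call__(self, length: int) -> List[List[float]]:
        table = []
        for pos in range(length):
            row = []
            for i in range(0, self.dim, 2):
                angle = pos / (10000 ** (i / self.dim))
                row.append(math.sin(angle))
                if i + 1 < self.dim:
                    row.append(math.cos(angle))
            table.append(row)
        return table


class LearnedPositionEmbedding:
    """Learned embedding: owns a (max_len, dim) parameter table."""

    def __init__(self, max_len: int, dim: int):
        self.variables = {"params": {"embedding": [[0.0] * dim for _ in range(max_len)]}}

    def __call__(self, length: int) -> List[List[float]]:
        return self.variables["params"]["embedding"][:length]


def make_position_embedding(kind: str, max_len: int, dim: int):
    """Hyperparameter-driven choice between the two options (hypothetical)."""
    if kind == "sine":
        return SinePositionEmbedding(dim)
    if kind == "learned":
        return LearnedPositionEmbedding(max_len, dim)
    raise ValueError(f"unknown kind: {kind}")


for kind in ("sine", "learned"):
    emb = make_position_embedding(kind, max_len=8, dim=4)
    print(kind, len(emb(5)), len(emb(5)[0]), list(emb.variables))
```

Code that inspects `emb.variables` works for both options only because the sine variant may return an empty dict.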

avital commented 3 years ago

> Sometimes you configure a module using hyperparameters, and one of the options is a parameter-free one, like using a sine positional embedding vs learned positional embeddings.

Ah that's interesting and I could totally see this.

I wonder if we can think of some other mechanism... I am not sure I actually like this, but, for example, a module could somehow mark when its variables should be fully populated. Say there was a call self.variables_finalized() that you could make either inside an nn.compact-wrapped method or within setup; calling submodule.variables before that would then raise an error. (I don't actually like this proposal, but I /think/(?) it would technically work, at the expense of making modules a bit more verbose to define.)
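A toy sketch of the `self.variables_finalized()` idea, outside Flax. `ToyModule`, `declare` and `variables_finalized` are hypothetical names: the module declares its variables, then marks itself finalized, and reading `variables` before that point raises instead of returning a partial or empty dict.

```python
from typing import Any, Dict


class ToyModule:
    """Hypothetical module that must be explicitly finalized before reading."""

    def __init__(self) -> None:
        self._variables: Dict[str, Any] = {}
        self._finalized = False

    def declare(self, name: str, value: Any) -> None:
        # Stand-in for creating a parameter during setup / a compact __call__.
        self._variables[name] = value

    def variables_finalized(self) -> None:
        # Called by the module author once all variables exist.
        self._finalized = True

    @property
    def variables(self) -> Dict[str, Any]:
        if not self._finalized:
            raise RuntimeError("variables read before variables_finalized() was called")
        return dict(self._variables)


m = ToyModule()
m.declare("kernel", [[1.0]])
try:
    m.variables
except RuntimeError as err:
    print("error:", err)
m.declare("bias", [0.0])
m.variables_finalized()
print(sorted(m.variables))  # ['bias', 'kernel']
```

The extra `variables_finalized()` call in every module is the verbosity cost mentioned above.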

I am not proposing any action item here, I just thought I'd share to see if it inspires you or anyone else for some proposal.

jheek commented 3 years ago

Still, there might be cases where you want to access variables even though they aren't final. Let's say I want to access a kernel but haven't necessarily defined the bias yet. A simple solution would be a helpful reminder to the user. For example, variables could be a subclass of FrozenDict that raises a helpful error on KeyError, like "Perhaps the variable you are trying to access is not initialized yet..."
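A minimal sketch of this suggestion, using a small stand-in for Flax's `FrozenDict` built on `collections.abc.Mapping` so it runs without Flax. The class name `HelpfulFrozenDict` and the message text are assumptions; only the `KeyError` path differs from a plain immutable mapping.

```python
from collections.abc import Mapping
from typing import Any, Dict, Iterator


class HelpfulFrozenDict(Mapping):
    """Immutable mapping whose KeyError hints at lazy initialization."""

    def __init__(self, data: Dict[str, Any]):
        self._data = dict(data)

    def __getitem__(self, key: str) -> Any:
        try:
            return self._data[key]
        except KeyError:
            raise KeyError(
                f"{key!r} not found. Perhaps the variable you are trying to "
                "access is not initialized yet; submodules using shape inference "
                "create their variables on their first call."
            ) from None

    def __iter__(self) -> Iterator[str]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)


params = HelpfulFrozenDict({"kernel": [[1.0]]})  # bias not defined yet
print(params["kernel"])
try:
    params["bias"]
except KeyError as err:
    print(err)
```

Because the error is still a `KeyError`, the inherited `Mapping.get` and `in` keep working unchanged. Nested values are returned as plain dicts here; a fuller version would wrap them too.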

cgarciae commented 2 years ago

Based on @jheek's comment in #638, how about we do the following:

  1. Create a simple ErrorHandlers dataclass, e.g.:
from dataclasses import dataclass
from typing import Any, Callable, Optional

@dataclass(frozen=True)
class ErrorHandlers:
  # Builds the exception to raise when a key is missing.
  missing_key: Optional[Callable[[str], Exception]] = None
  # Builds the exception to raise on an attempted item assignment.
  set_item: Optional[Callable[[str, Any], Exception]] = None
  ...  # further handlers
  2. Update FrozenDict to have an _error_handlers field and add logic to use the handlers in a couple of places.
  3. Add a set_error_handlers method to FrozenDict, as proposed by @jheek, which takes in the handlers. The only modification I would propose is that it returns a new FrozenDict to keep the API immutable.
xs = FrozenDict(variables)
xs = xs.set_error_handlers(missing_key=lambda key: Error(...), set_item=lambda key, value: Error(...))

We could start with a single handler that solves this specific issue.
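A framework-free sketch of these steps, assuming a stripped-down FrozenDict stand-in rather than Flax's real class. `ErrorHandlers`, `_error_handlers` and `set_error_handlers` follow the proposal's names but are not an existing Flax API, and `UninitializedVariableError` is hypothetical. Only the `missing_key` handler is wired up, matching the suggestion to start with a single handler.

```python
from collections.abc import Mapping
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, Iterator, Optional


@dataclass(frozen=True)
class ErrorHandlers:
    # Builds the exception raised for a missing key (None = default KeyError).
    missing_key: Optional[Callable[[str], Exception]] = None


class FrozenDict(Mapping):
    """Minimal immutable mapping with pluggable error handlers (sketch only)."""

    def __init__(self, data: Dict[str, Any], error_handlers: ErrorHandlers = ErrorHandlers()):
        self._data = dict(data)
        self._error_handlers = error_handlers

    def __getitem__(self, key: str) -> Any:
        if key in self._data:
            return self._data[key]
        handler = self._error_handlers.missing_key
        raise handler(key) if handler is not None else KeyError(key)

    def __contains__(self, key: object) -> bool:
        return key in self._data

    def get(self, key: str, default: Any = None) -> Any:
        # Lookups with a default never trigger the missing-key handler.
        return self._data.get(key, default)

    def __iter__(self) -> Iterator[str]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def set_error_handlers(self, **handlers: Any) -> "FrozenDict":
        # Return a new FrozenDict instead of mutating self (immutable API).
        return FrozenDict(self._data, replace(self._error_handlers, **handlers))


class UninitializedVariableError(KeyError):
    """Hypothetical error type produced by the handler below."""


xs = FrozenDict({"kernel": [[1.0]]})
xs = xs.set_error_handlers(
    missing_key=lambda key: UninitializedVariableError(
        f"{key!r} not found; perhaps the variable is not initialized yet."
    )
)
try:
    xs["bias"]
except UninitializedVariableError as err:
    print("error:", err)
```

Overriding `__contains__` and `get` keeps membership tests and defaulted lookups from invoking the handler, so a handler is free to return any exception type.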