avital opened 4 years ago
(It's also worth considering writing a "design note" explaining the design of shape inference in Linen)
@avital I'm not sure we should always raise an error if the variables are empty when they are read. If the module does not have any variables, then the dict will always be empty, and maybe that's what's intended. Like in this (maybe contrived) example:
```python
from flax import linen as nn
from jax import random
import jax.numpy as jnp


class MyRelu(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.relu(x)


class UsesMyRelu(nn.Module):
    @nn.compact
    def __call__(self, x):
        relu = MyRelu()
        variables = relu.variables
        y = relu(x)
        variables_after = relu.variables
        print(f"Variables Before: {variables}")
        print(f"Variables After: {variables_after}")
        return y


x = jnp.array([[-1, 5],
               [3, -5]])
rng = dict(params=random.PRNGKey(0))
relu_mod = UsesMyRelu()
output, params = relu_mod.init_with_output(rng, x)
```

Output:

```
Variables Before: FrozenDict({})
Variables After: FrozenDict({})
```
On the one hand, if a user tries to access the variables of `MyRelu`, then at the very least that's probably not what they wanted to do. So maybe an error makes sense.
On the other hand, maybe some folks are doing this programmatically, e.g. a `Sequential` wrapper where some submodules have variables and some don't. You could imagine wanting to, for example, count the total number of parameters.
On the (third?) hand, if you're going to use relu (or any function or object without variables), why would you wrap it in a Module to begin with? Functions work fine. (Or if you want multiple methods, normal objects work fine.)
To summarize, your argument is interesting, but I still don't understand the full use case where it would be a problem.
Thanks for explaining that! I agree that wrapping functions without variables in a Module doesn't make much sense.
I don't think checking for empty variables dicts as a safety check is super useful. The assertion doesn't guarantee that all variables exist, only that at least one does, so users still need to understand laziness. Also, it only works during init, because during apply `self.variables` always returns all the parameters.
The downside of making this an error is that if someone writes a generic piece of code that relies on `module.variables`, it will now need a `try...except` statement to support variable-free modules.
Why would someone make variable-free modules rather than just use functions?
I've seen it a few times already in bigger projects. Sometimes you configure a module using hyperparameters, and one of the options is a parameter-free one, like using a sine positional embedding vs. learned positional embeddings.
Ah that's interesting and I could totally see this.
I wonder if we can think about some other mechanism... I am not sure I actually like this, but for example a module could somehow mark when its variables should be fully populated. Like, if there were a call `self.variables_finalized()` that you could call either inside an `nn.compact`-wrapped method or within `setup`. And then if you try calling `submodule.variables()` before that, you get an error. (I don't actually like this proposal, but I /think/(?) it would technically work at the expense of making it a bit more verbose to define modules.)
I am not proposing any action item here, I just thought I'd share to see if it inspires you or anyone else for some proposal.
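To illustrate the contract that proposal implies, here is a toy in plain Python. This is not Flax code: `ToyModule`, `ToyDense`, and `NotFinalizedError` are hypothetical, and only `variables_finalized` comes from the comment above. Reading `variables` before the module declares itself finalized raises an error:

```python
class NotFinalizedError(RuntimeError):
    """Hypothetical error for reading variables too early."""


class ToyModule:
    """Plain-Python stand-in for a module; not Flax code."""

    def __init__(self):
        self._variables = {}
        self._finalized = False

    def variables_finalized(self):
        # The module author declares that every variable now exists.
        self._finalized = True

    @property
    def variables(self):
        if not self._finalized:
            raise NotFinalizedError(
                f"{type(self).__name__}.variables was read before "
                "variables_finalized() was called"
            )
        return dict(self._variables)


class ToyDense(ToyModule):
    """Creates its kernel lazily on the first call, once the input size is known."""

    def __init__(self, features):
        super().__init__()
        self.features = features

    def __call__(self, x):
        if not self._finalized:
            self._variables["kernel"] = [[1.0] * self.features for _ in x]
            self.variables_finalized()
        kernel = self._variables["kernel"]
        return [sum(xi * kernel[i][j] for i, xi in enumerate(x)) for j in range(self.features)]


layer = ToyDense(features=3)
try:
    _ = layer.variables
    early_read_failed = False
except NotFinalizedError:
    early_read_failed = True
y = layer([1.0, 2.0])
```

The extra verbosity is visible: every module, including variable-free ones, would need to call `variables_finalized()` somewhere.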
Still, there might be cases where you want to access variables even though they aren't final. Let's say I want to access a kernel but I haven't necessarily defined the bias yet. A simple solution would be a helpful reminder to the user. For example, `variables` could be a subclass of `FrozenDict` that raises a helpful error on `KeyError`, like "Perhaps the variable you are trying to access is not initialized yet..."
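A minimal sketch of that idea follows. It uses a plain `collections.abc.Mapping` as a stand-in for a `FrozenDict` subclass, and `HelpfulVariables` is a hypothetical name. Nested mappings are wrapped so the hint also appears for keys like `variables["params"]["bias"]`:

```python
from collections.abc import Mapping


class HelpfulVariables(Mapping):
    """Read-only mapping whose KeyError explains lazy initialization (hypothetical)."""

    def __init__(self, data):
        self._data = dict(data)

    def __getitem__(self, key):
        try:
            value = self._data[key]
        except KeyError:
            raise KeyError(
                f"{key!r} not found. Perhaps the variable you are trying to access "
                "is not initialized yet: modules that use shape inference create "
                "their variables on the first call."
            ) from None
        # Wrap nested collections so lookups deeper in the tree get the same hint.
        return HelpfulVariables(value) if isinstance(value, Mapping) else value

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


variables = HelpfulVariables({"params": {"kernel": [[0.5, -0.5]]}})
kernel = variables["params"]["kernel"]
try:
    variables["params"]["bias"]  # the bias has not been created yet
    message = None
except KeyError as err:
    message = err.args[0]
```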
Based on @jheek's comment in #638, how about we do the following:

1. Define an `ErrorHandlers` dataclass, e.g.:
```python
from dataclasses import dataclass
from typing import Any, Callable, Optional

@dataclass(frozen=True)
class ErrorHandlers:
    missing_key: Optional[Callable[[str], Exception]] = None
    set_item: Optional[Callable[[str, Any], Exception]] = None
    ...
```
2. Modify `FrozenDict` to have an `_error_handlers` field and add logic to use the handlers in a couple of places.
3. Add a `set_error_handlers` method to `FrozenDict`, as proposed by @jheek, which takes in the handlers. The only modification I would propose is that it returns a new `FrozenDict` to keep the API immutable:
```python
xs = FrozenDict(variables)
xs = xs.set_error_handlers(missing_key=lambda key: Error(...), set_item=lambda key, value: Error(...))
```

We could start with a single handler that solves this specific issue.
Within a module that uses shape-inference (as most of the built-in Linen modules do), this code is fine:
But if you instead do:
Then I believe you get an empty `params` dict (as the parameters of `conv` are only initialized once the input shape is known). Users may be surprised by this, so instead we could raise an error when the variables are empty at read time, explaining what is happening and pointing users in the right direction.