google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Multiple Inheritance -> doesn't recognize as Module throws ValueError: parent must be None, Module or Scope #1409

Open avital opened 3 years ago

avital commented 3 years ago

Discussed in https://github.com/google/flax/discussions/1390

Originally posted by **SauravMaheshkar** June 26, 2021

I'm working on a Flax implementation for [ProteinBERT: A universal deep-learning model of protein sequence and function](https://www.biorxiv.org/content/10.1101/2021.05.24.445464v1). My work so far is in [SauravMaheshkar/ProteinBERT](https://github.com/SauravMaheshkar/ProteinBERT). I've made a simple `test.py` to check instantiation using the `.init()` function. My test script is as follows:

```
from proteinbert import ProteinBERT
import jax
from jax import random


def test():
    seq = jax.random.randint(
        key=random.PRNGKey(0), minval=0, maxval=21, shape=(2, 2048)
    )
    annotation = jax.random.randint(
        key=random.PRNGKey(0), minval=0, maxval=1, shape=(2, 8943)
    )
    init_rngs = {"params": random.PRNGKey(0), "layers": random.PRNGKey(1)}
    ProteinBERT().init(init_rngs, seq, annotation)


if __name__ == "__main__":
    test()
```

And I've been getting this error message:

Error Message

```
Traceback (most recent call last):
  File "/Users/sauravmaheshkar/github/protein_bert/test.py", line 21, in <module>
    test()
  File "/Users/sauravmaheshkar/github/protein_bert/test.py", line 17, in test
    ProteinBERT().init(init_rngs, seq, annotation)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 1000, in init
    method=method, mutable=mutable, **kwargs)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 969, in init_with_output
    {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 939, in apply
    )(variables, *args, **kwargs, rngs=rngs)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/core/scope.py", line 687, in wrapper
    y = fn(root, *args, **kwargs)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 1178, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 266, in wrapped_module_method
    self._try_setup()
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 679, in _try_setup
    self.setup()
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 275, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/Users/sauravmaheshkar/github/protein_bert/proteinbert/model.py", line 82, in setup
    Reduce("b n d -> b d", "mean"),
  File "<string>", line 5, in __init__
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 599, in __post_init__
    raise ValueError("parent must be None, Module or Scope")
ValueError: parent must be None, Module or Scope
```

The problem lies in the `Reduce` class in [proteinbert/utils.py](https://github.com/SauravMaheshkar/ProteinBERT/blob/jaxpackage/proteinbert/utils.py), which is defined as follows:

```
class Reduce(ReduceMixin, nn.Module):
    """
    Flax Module to act as a Reduce layer (from einops)
    """

    def __call__(self, input):
        return self._apply_recipe(input)
```

The idea is to create a `Reduce` layer/Module for Flax which performs the `reduce` operation from `einops`. Although the module inherits from `flax.linen.Module`, it still throws a `ValueError`. Any help would be much appreciated 😊.
avital commented 3 years ago

@marcvanzee will be investigating this.

A few guiding questions:

  1. Why would einops be implemented as a Module instead of just a function?
  2. Why is multiple inheritance needed here?
  3. Regardless, this error shouldn't happen. So even if we answer (1) and (2) in a way that means there's a workaround, we should still fix this bug.
levskaya commented 3 years ago

Just noticing this issue for the first time... I've seen similarly weird issues with mixins and Flax resolved in the past by simply changing the order of the multiple inheritance, e.g. `class Reduce(nn.Module, ReduceMixin):` to put `nn.Module` first. I'm not 100% sure this is the same kind of issue that I've seen before with mixins, but I'd certainly be curious whether that would have fixed the issue...

cgarciae commented 2 years ago

I was playing around with mixins to see how they interact with Module.

See experiments

The following case works:

```python
import flax.linen as nn
import jax.numpy as jnp
import jax


class Mixin:
    def __call__(self, x):
        return self.dense(x)


class MyModule(nn.Module, Mixin):
    def setup(self):
        self.dense = nn.Dense(2)


module_a = MyModule()
variables = module_a.init(jax.random.PRNGKey(0), jnp.ones((1, 1)))
```

However, moving `setup` into `Mixin` fails:

```python
class Mixin:
    def setup(self):
        self.dense = nn.Dense(2)

    def __call__(self, x):
        return self.dense(x)


class MyModule(nn.Module, Mixin):
    pass

# AttributeError: "MyModule" object has no attribute "dense"
```

This is again fixed if `Mixin` is set as the first parent:

```python
class Mixin:
    def setup(self):
        self.dense = nn.Dense(2)

    def __call__(self, x):
        return self.dense(x)


class MyModule(Mixin, nn.Module):
    pass
```

Also, you cannot define `compact` methods on mixins (this is probably expected?):

```python
class Mixin:
    @nn.compact
    def __call__(self, x):
        return nn.Dense(2)(x)


class MyModule(nn.Module, Mixin):  # swapping doesn't help
    pass
```
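
One possible reading of the two `setup` experiments is Python's method resolution order (MRO): if the first base class defines its own `setup`, the mixin's `setup` is never reached. The sketch below uses plain-Python stand-ins only. `Module` here is a hypothetical class that mimics a base with its own no-op `setup`; it is not Flax code, and it does not prove this is what happens inside Flax.

```python
class Module:
    """Hypothetical stand-in for a base class with its own no-op setup."""

    def setup(self):
        return "Module.setup"


class Mixin:
    def setup(self):
        return "Mixin.setup"


class ModuleFirst(Module, Mixin):
    pass


class MixinFirst(Mixin, Module):
    pass


# The MRO lists the classes in the order Python searches for attributes.
print([c.__name__ for c in ModuleFirst.__mro__])
print([c.__name__ for c in MixinFirst.__mro__])

print(ModuleFirst().setup())  # Module.setup
print(MixinFirst().setup())   # Mixin.setup
```

With `Module` listed first, its `setup` shadows the mixin's, which would be consistent with the missing `dense` attribute in the failing case.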

Discussion

Based on these experiments, the only insight I see is: don't define scope-dependent operations (`compact`, `self.param`/`self.variable`) inside mixins, as their methods will not be wrapped appropriately. I'm not sure there is a way to properly wrap mixin methods in `__init_subclass__`: either `_get_local_method_names` is not detecting them, or they are not available when `__init_subclass__` is called.
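
To make the first hypothesis concrete, here is a toy model of an `__init_subclass__` hook that wraps only the methods found in the subclass's own `__dict__`. Everything in it (`Base`, `_wrap`, `is_wrapped`, `Child`) is hypothetical and written for illustration. It is a sketch of the general mechanism, assuming the hook looks only at locally defined methods, and is not Flax's actual implementation.

```python
import functools


def _wrap(fn):
    """Hypothetical wrapper that just marks a method as wrapped."""

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        return fn(self, *args, **kwargs)

    wrapper.is_wrapped = True
    return wrapper


class Base:
    """Hypothetical stand-in for a Module-like base class."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # vars(cls) holds only names defined directly in the class body,
        # so methods inherited from a mixin never show up here.
        for name, attr in list(vars(cls).items()):
            if callable(attr) and not name.startswith("__"):
                setattr(cls, name, _wrap(attr))


class Mixin:
    def forward(self, x):
        return x + 1


class Child(Base, Mixin):
    def local(self, x):
        return x * 2


print(getattr(Child.local, "is_wrapped", False))    # True
print(getattr(Child.forward, "is_wrapped", False))  # False
```

In this toy model the mixin method stays unwrapped whichever order the bases are listed in, because the hook runs only for subclasses of `Base` and inspects only their own class bodies.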

carlosgmartin commented 2 months ago

I'm also having trouble with mixins. I suggest adding some documentation about how to make flax modules play well with mixins.