patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Problems with access to class variables in __init_subclass__ during instance creation for equinox Modules #858

Open SimonKoop opened 2 weeks ago

SimonKoop commented 2 weeks ago

Since version 0.11.6, the following results in a strange AttributeError:

```python
import jax
import equinox as eqx

class Parent(eqx.Module):
    abs_cls_var: eqx.AbstractClassVar[str]

    def __init__(self, **kwargs):
        pass

    def __init_subclass__(cls):
        """__init_subclass__ 
        tries to access cls.abs_cls_var
        """
        print(cls.abs_cls_var)

class Child(Parent):
    abs_cls_var = 'w0'

Child()
```

The AttributeError is raised on the last line (the creation of an instance of Child). The stack trace in Google Colab is:

```
AttributeError                            Traceback (most recent call last)

<ipython-input-7-60f7e23927a5> in <cell line: 1>()
----> 1 Child()

    [... skipping hidden 4 frame]

/usr/local/lib/python3.10/dist-packages/equinox/_module.py in _make_initable_wrapper(cls)
    788 def _make_initable_wrapper(cls: _ActualModuleMeta) -> _ActualModuleMeta:
    789     post_init = getattr(cls, "__post_init__", None)
--> 790     return _make_initable(cls, cls.__init__, post_init, wraps=False)
    791 
    792 

/usr/local/lib/python3.10/dist-packages/equinox/_module.py in _make_initable(***failed resolving arguments***)
    806         field_names = {field.name for field in dataclasses.fields(cls)}
    807 
--> 808     class _InitableModule(cls, _Initable):
    809         pass
    810 

/usr/local/lib/python3.10/dist-packages/equinox/_module.py in __new__(mcs, name, bases, dict_, strict, **kwargs)
    200 
    201         # [Step 1] Create the class as normal.
--> 202         cls = super().__new__(mcs, name, bases, dict_, **kwargs)
    203         # [Step 2] Arrange for bound methods to be treated as PyTrees as well. This
    204         # ensures that

/usr/local/lib/python3.10/dist-packages/equinox/_better_abstract.py in __new__(mcs, name, bases, namespace, **kwargs)
    176 
    177     def __new__(mcs, name, bases, namespace, /, **kwargs):
--> 178         cls = super().__new__(mcs, name, bases, namespace, **kwargs)
    179 
    180         # We don't try and check that our AbstractVars and AbstractClassVars are

/usr/lib/python3.10/abc.py in __new__(mcls, name, bases, namespace, **kwargs)
    104         """
    105         def __new__(mcls, name, bases, namespace, **kwargs):
--> 106             cls = super().__new__(mcls, name, bases, namespace, **kwargs)
    107             _abc_init(cls)
    108             return cls

<ipython-input-6-b018fe5e563f> in __init_subclass__(cls)
      9         tries to access cls.abs_cls_var
     10         """
---> 11         print(cls.abs_cls_var)
     12 
     13 

/usr/local/lib/python3.10/dist-packages/equinox/_module.py in __getattribute__(cls, item)
    610     # `module_update_wrapper`, but if `dataclass` sees it then it tries to follow it.
    611     def __getattribute__(cls, item):
--> 612         value = super().__getattribute__(item)
    613         if (
    614             item == "__wrapped__"

AttributeError: type object '_InitableModule' has no attribute 'abs_cls_var'
```
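The traceback shows that calling Child() leads Equinox to define a new class, _InitableModule, that subclasses Child (the `class _InitableModule(cls, _Initable)` frame), and Python then runs Parent.__init_subclass__ again for that generated class, which is where the lookup of abs_cls_var fails. As a minimal sketch of just the first half of that mechanism, here is plain Python with no Equinox involved; the make_wrapper helper is hypothetical and only mimics the shape of the traceback, it is not Equinox's actual implementation:

```python
# Records every class for which Parent.__init_subclass__ runs.
seen = []


class Parent:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        seen.append(cls.__name__)


class Child(Parent):
    abs_cls_var = "w0"


def make_wrapper(cls):
    """Hypothetical stand-in for a library that builds a subclass at
    instantiation time (NOT Equinox's real code)."""

    class _Wrapper(cls):
        pass

    return _Wrapper


Wrapped = make_wrapper(Child)
print(seen)  # ['Child', '_Wrapper']
```

Defining Child runs the hook once; creating the wrapper subclass runs it a second time, even though user code never wrote `class _Wrapper`. Any __init_subclass__ with side effects or attribute lookups therefore also sees library-generated subclasses.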

This problem did not occur in Equinox version 0.11.5. It also doesn't occur if you comment out the __init__ method in Parent.

patrick-kidger commented 2 weeks ago

Thanks for the report! I've just pushed a commit, so this should be fixed on the latest HEAD.

SimonKoop commented 2 weeks ago

Wow, that was quick! Thanks for fixing this so swiftly!