patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Bug: typing issue due to `__getattribute__` #775

Open nstarman opened 4 months ago

nstarman commented 4 months ago

Discovered in https://github.com/GalacticDynamics/galax/pull/377: when jaxtyping's runtime type-checking is turned on, `Module.__getattribute__` does not handle modules that are also `Generic`.

The traceback looks like:

```
../../python3.11/typing.py:1834: in __class_getitem__
    for param in cls.__parameters__:
        cls        = <class 'ParametricClass'>
        params     = (~T,)
../../python3.11/site-packages/equinox/_module.py:582: in __getattribute__
    value = super().__getattribute__(item)
E   AttributeError: type object 'ParametricClass' has no attribute '__parameters__'
        __class__  = <class 'equinox._module._ModuleMeta'>
        cls        = <class 'ParametricClass'>
        item       = '__parameters__'
```

I think `__parameters__` might need to be special-cased.
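For illustration, here is a toy reproduction of the same error shape in plain Python, without Equinox, jaxtyping, or beartype. `HidingMeta` and `Broken` are hypothetical names, and the metaclass deliberately refuses to resolve `__parameters__`; that refusal is an assumption standing in for whatever actually goes wrong inside Equinox's `_ModuleMeta.__getattribute__`. Subscripting the class then fails inside `typing` at the same `cls.__parameters__` read shown in the traceback.

```python
from typing import Generic, TypeVar

T = TypeVar("T")


class HidingMeta(type):
    # Hypothetical metaclass, NOT Equinox's _ModuleMeta: it simulates a
    # __getattribute__ override that cannot resolve __parameters__.
    def __getattribute__(cls, item):
        if item == "__parameters__":
            raise AttributeError(
                f"type object {cls.__name__!r} has no attribute {item!r}"
            )
        return super().__getattribute__(item)


class Broken(Generic[T], metaclass=HidingMeta):
    pass


error = None
try:
    # Generic's __class_getitem__ iterates over cls.__parameters__,
    # which goes through HidingMeta.__getattribute__ and raises.
    Broken[int]
except AttributeError as exc:
    error = exc

print(error)  # type object 'Broken' has no attribute '__parameters__'
```

Class creation itself still succeeds here: `Generic.__init_subclass__` *writes* `cls.__parameters__`, and writes go through `__setattr__`, which a `__getattribute__` override does not intercept. Only the later read during subscription fails.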

patrick-kidger commented 4 months ago

Do you have a MWE?

(For what it's worth I use generics successfully with Equinox elsewhere.)

nstarman commented 4 months ago

> Do you have a MWE?

I'll try to make one.

> (For what it's worth I use generics successfully with Equinox elsewhere.)

Do you have jaxtyping + beartype turned on? Beartype appears in the traceback when it calls `typing._generic_class_getitem`, which is where the failed `__parameters__` attribute retrieval originates.
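A small sketch of why the metaclass is involved at all: subscripting a `Generic` subclass reads `cls.__parameters__` inside `typing`, and that read is dispatched through the metaclass's `__getattribute__`. `RecordingMeta` and `Box` are hypothetical names; the metaclass only logs lookups and delegates to `type`, so it assumes nothing about Equinox's internals.

```python
from typing import Generic, TypeVar

T = TypeVar("T")

# Every attribute name looked up on classes that use RecordingMeta.
seen: list[str] = []


class RecordingMeta(type):
    # Hypothetical metaclass for illustration only: it does not hide or
    # alter any attribute, it just records the lookup and delegates.
    def __getattribute__(cls, item):
        seen.append(item)
        return super().__getattribute__(item)


class Box(Generic[T], metaclass=RecordingMeta):
    pass


seen.clear()  # ignore lookups made while the class was being created
alias = Box[int]  # Generic's __class_getitem__ reads Box.__parameters__

print("__parameters__" in seen)  # True
print(alias.__args__)  # (<class 'int'>,)
```

So any metaclass `__getattribute__` (Equinox's included) sits on the path of every `Module[...]` subscription, which is why beartype resolving an annotation like `Parametric[T]` can surface an error raised in `_module.py`.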

patrick-kidger commented 4 months ago

Cheers! FWIW I do often also combine Equinox + jaxtyping + beartype. Admittedly that is now a more complicated stack (beartype especially), so I'm definitely willing to believe something goes wrong :D

nstarman commented 4 months ago

Hmm. It's challenging to reproduce the failure I'm seeing in https://github.com/GalacticDynamics/galax/pull/377. The obvious minimal example doesn't raise the same error.

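```python
from typing import Generic, TypeVar

import equinox as eqx
from beartype import beartype as typechecker
# To test with typeguard instead, swap the import above for:
# from typeguard import typechecked as typechecker
from jaxtyping import jaxtyped

T = TypeVar("T")

# A generic Equinox module, type-checked at runtime via jaxtyping.
@jaxtyped(typechecker=typechecker)
class Parametric(eqx.Module, Generic[T]):
    value: T

# Subscripting Parametric[T] in the annotations is what goes through
# Generic's __class_getitem__ (and hence the __parameters__ lookup).
@jaxtyped(typechecker=typechecker)
def function(parametric: Parametric[T]) -> Parametric[T]:
    return parametric
```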