patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Equinox Module methods not recognized by beartype #584

Open EtaoinWu opened 8 months ago

EtaoinWu commented 8 months ago

In the following code:

from beartype import beartype
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, Float

@beartype
class MyClass(eqx.Module):
  x: Float[Array, ""]

  def fn(self, y: bool) -> Float[Array, ""]:
    return self.x + 1

You can actually call MyClass(jnp.array(1.)).fn('not bool') without getting a roar from the bear.

The reason is that beartype, when decorating a class, iterates through each attribute of its __dict__. In our case, MyClass.__dict__['fn'] (different from MyClass.fn!) is an equinox._module._wrap_method, and is not beartype-able.
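A minimal, Equinox-free sketch of the mechanism described above. `WrappedMethod` is a hypothetical stand-in for a method wrapper (not Equinox's actual `_wrap_method`), and `functions_found_in_dict` only imitates a `__dict__`-based scan; beartype's real traversal may differ in detail.

```python
import inspect


class WrappedMethod:
    """Hypothetical stand-in for a library's method wrapper."""

    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner):
        # Accessed on the class: hand back the plain function.
        if instance is None:
            return self.func
        # Accessed on an instance: hand back a bound method.
        return self.func.__get__(instance, owner)


class MyClass:
    def plain(self, y):
        return y

    # Simulates a method that the library wrapped at class creation.
    fn = WrappedMethod(lambda self, y: y)


def functions_found_in_dict(cls):
    """Names that a __dict__-based scan would treat as plain functions."""
    return [k for k, v in vars(cls).items() if inspect.isfunction(v)]


print(functions_found_in_dict(MyClass))  # the wrapped method is not listed
print(inspect.isfunction(MyClass.fn))    # attribute access unwraps it
```

The scan only sees the wrapper object stored in the class dictionary, so the wrapped method is silently skipped, while `MyClass.fn` goes through `__get__` and yields an ordinary function.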

Current workaround

I use a dirty hack to add beartype to eqx.Modules.

Potential fix

I cannot think of a perfect way to fix this. Here are some thoughts.

def _wrap_method(method):
    @ft.wraps(method)
    def method_(*args, **kwargs):
        return method(*args, **kwargs)
    def _get(self, instance, owner):
        if instance is None:
            return method
        else:
            _method = _module_update_wrapper(
                BoundMethod(method, instance), None, inplace=True
            )
            return _method
    setattr(method_, '__get__', types.MethodType(_get, method_))
    return method_

Not sure if this would break other parts of Equinox.

EtaoinWu commented 8 months ago

The final note seems to break a bunch of stuff in test_module.py.

It doesn't actually work -- a function's __get__ attribute seems to be ignored.

import equinox as eqx

class MyModule(eqx.Module):
    a: int

    def f(self, b):
        return self.a + b

x = MyModule(1)
print(x.f)
f = MyModule.__dict__['f']
print(f.__get__(x, MyModule))

The code outputs:

<bound method MyModule.f of MyModule(a=1)>
BoundMethod(__func__=<function f>, __self__=MyModule(a=1))
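
This matches Python's descriptor protocol: when an attribute is found on a class, the interpreter calls `__get__` as defined on the type of the stored object, not an attribute of that name set on the object itself. A small sketch without Equinox; `_get` and its marker string are hypothetical and exist only to make the two lookup paths visible.

```python
import types


def f(self, b):
    return self.a + b


def _get(func, instance, owner):
    # Hypothetical override returning a marker instead of a bound method.
    return "custom __get__ was used"


# Functions accept arbitrary attributes, so this assignment succeeds...
f.__get__ = types.MethodType(_get, f)


class MyModule:
    def __init__(self, a):
        self.a = a


MyModule.f = f
x = MyModule(1)

# ...but attribute lookup calls type(f).__get__, so the override is bypassed.
print(type(x.f))
# Calling the attribute explicitly does reach the override.
print(f.__get__(x, MyModule))
```

To change binding behaviour, the object stored in the class has to be an instance of a class that defines `__get__`, rather than a function with a patched attribute.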

patrick-kidger commented 8 months ago

Hmm, interesting! So first of all, the fix is probably to move the @beartype decorator from the class to the method:

class Foo(eqx.Module):
    @beartype
    def bar(self): ...

I'm aware that this isn't what beartype recommends, but in reality this handles 99% of all cases. (And decorating the method like this is actually what I always do as well.)


Tagging @leycec, as this has got me thinking. I'd like to propose that beartype recommend decorating methods instead. On the two points raised in the docs above:

The benefits would be that:

WDYT?

leycec commented 8 months ago

I don't think @beartype would bother to add support for Equinox.

Hah! I'll show you, @EtaoinWu. Friendship is magic. @beartype is definitely prepared to support Big Boss @patrick-kidger and all of his associated JAX madness in whatever way he needs – including rolling out internal support for Equinox-specific wrappers.

After all, @beartype already internally supports non-standard third-party NumPy and Pandera type hints. Equinox is just yet another non-standard third-party thing that @beartype would support.

Ideally, of course, @beartype would provide some sort of plugin API for this sort of thing. But... ain't nobody got that kind of time. I'll just hack it instead. :smile:

I'd like to propose that beartype recommend decorating methods instead.

In response, @beartype would like to propose this Berserk meme.

I think class-specific stuff like typing.Self can be handled by checking the type of the first argument. Indeed you have to be doing this anyway, since Self refers to that type, not to the type of the class in which the method is defined!

Hah! @beartype showed you all. Indeed, @beartype type-checks typing.Self as the type of the class in which the method is defined. Checkmate, @patrick-kidger. This is actually the optimally efficient means of type-checking self-ness. It's also the optimally safe means of type-checking self-ness, because the self parameter passed to bound instance methods can be a deceptive lie (e.g., when externally defined functions monkey-patch themselves into a class via the __get__() descriptor).

Moreover, doing so enables @beartype to additionally type-check typing.Self in all possible contexts – including those that lack access to the self parameter passed to bound instance methods:

@beartype
@dataclass
class BadExampleIsBad(object):
     myself: Self  # <-- Yup. @beartype can type-check this.

     @staticmethod
     def make_myself() -> Self:  # <-- Yup. @beartype can type-check this, too.
         return BadExampleIsBad()

     @classmethod
     def ignore_myself(cls: type[Self]) -> Self:  # <-- Still fine.
         return cls.make_myself()

But there are actually many, many reasons apart from PEP 673 (i.e., typing.Self) why decorating the class rather than method is preferable. This includes:

Relevant commentary in the beartype._check.forward.fwdhint submodule includes:

        # If the decorated callable is nested (rather than global) and thus
        # *MAY* have a non-empty local nested scope...
        if bear_call.func_wrappee_is_nested:
            # Attempt to...
            try:
                # Local scope of the decorated callable, localized to improve
                # readability and negligible efficiency when accessed below.
                func_locals = get_func_locals(
                    func=func,

                    # Ignore all lexical scopes in the fully-qualified name of
                    # the decorated callable corresponding to parent classes
                    # lexically nesting the current decorated class containing
                    # that callable (including that class). Why? Because these
                    # classes are *ALL* currently being decorated and thus have
                    # yet to be encapsulated by new stack frames on the call
                    # stack. If these lexical scopes are *NOT* ignored, this
                    # call to get_func_locals() will fail to find the parent
                    # lexical scope of the decorated callable and then raise an
                    # unexpected exception.
                    #
                    # Consider, for example, this nested class decoration of a
                    # fully-qualified "muh_package.Outer" class:
                    #     @beartype
                    #     class Outer(object):
                    #         class Middle(object):
                    #             class Inner(object):
                    #                 def muh_method(self) -> str:
                    #                     return 'Painful API is painful.'
                    #
                    # When @beartype finally recurses into decorating the nested
                    # muh_package.Outer.Middle.Inner.muh_method() method, this
                    # call to get_func_locals() if *NOT* passed this parameter
                    # would naively assume that the parent lexical scope of the
                    # current muh_method() method on the call stack is named
                    # "Inner". Instead, the parent lexical scope of that method
                    # on the call stack is named "muh_package" -- the first
                    # lexical scope enclosing that method that exists on the
                    # call stack. The non-existent "Outer", "Middle", and
                    # "Inner" lexical scopes must *ALL* be silently ignored.
                    func_scope_names_ignore=(
                        0 if cls_stack is None else len(cls_stack)),

                    #FIXME: Consider dynamically calculating exactly how many
                    #additional @beartype-specific frames are ignorable on the first
                    #call to this function, caching that number, and then reusing
                    #that cached number on all subsequent calls to this function.
                    #The current approach employed below of naively hard-coding a
                    #number of frames to ignore was incredibly fragile and had to be
                    #effectively disabled, which hampers runtime efficiency.

                    # Ignore additional frames on the call stack embodying:
                    # * The current call to this function.
                    #
                    # Note that, for safety, we currently avoid ignoring
                    # additional frames that we could technically ignore. These
                    # include:
                    # * The call to the parent
                    #   beartype._check.checkcall.BeartypeCall.reinit() method.
                    # * The call to the parent @beartype.beartype() decorator.
                    #
                    # Why? Because the @beartype codebase has been sufficiently
                    # refactored so as to render any such attempts non-trivial,
                    # fragile, and frankly dangerous.
                    func_stack_frames_ignore=1,
                    exception_cls=exception_cls,
                )
            # If this local scope cannot be found (i.e., if this getter found
            # the lexical scope of the module declaring the decorated callable
            # *before* that of the parent callable or class declaring that
            # callable), then this resolve_hint() function was called *AFTER*
            # rather than *DURING* the declaration of the decorated callable.
            # This implies that that callable is not, in fact, currently being
            # decorated. Instead, that callable was *NEVER* decorated by
            # @beartype but has instead subsequently been passed to this
            # resolve_hint() function after its initial declaration -- typically
            # due to an external caller passing that callable to our public
            # beartype.peps.resolve_pep563() function.
            #
            # In this case, the call stack frame providing this local scope has
            # (almost certainly) already been deleted and is no longer
            # accessible. We have no recourse but to default this local scope to
            # the empty dictionary -- which might be subsequently modified and
            # *CANNOT* thus default to the singleton empty dictionary
            # "DICT_EMPTY" (unlike below).
            except _BeartypeUtilCallableScopeNotFoundException:
                func_locals = {}

            # If the decorated callable is a method transitively defined by a
            # root decorated class, add a pair of local attributes exposing:
            #
            # * The unqualified basename of the root decorated class. Why?
            #   Because this class may be recursively referenced in postponed
            #   type hints and *MUST* thus be exposed to *ALL* postponed type
            #   hints. However, this class is currently being decorated and thus
            #   has yet to be defined in either:
            #   * If this class is module-scoped, the global attribute
            #     dictionary of that module and thus the "func_globals"
            #     dictionary.
            #   * If this class is closure-scoped, the local attribute
            #     dictionary of that closure and thus the "func_locals"
            #     dictionary.
            # * The unqualified basename of the current decorated class. Why?
            #   For similar reasons. Since the current decorated class may be
            #   lexically nested in the root decorated class, the current
            #   decorated class is *NOT* already accessible as either a global
            #   or local. Exposing the current decorated class to a stringified
            #   type hint referencing that class thus requires adding a local
            #   attribute exposing that class.
            #
            # Note that:
            # * *ALL* intermediary classes (i.e., excluding the root decorated
            #   class) lexically nesting the current decorated class are
            #   irrelevant. Intermediary classes are neither module-scoped nor
            #   closure-scoped and thus inaccessible as either globals or locals
            #   in the nested lexical scope of the current decorated class:
            #   e.g.,
            #     # This raises a parser error and is thus *NOT* fine:
            #     #     NameError: name 'muh_type' is not defined
            #     class Outer(object):
            #         class Middle(object):
            #             muh_type = str
            #
            #             class Inner(object):
            #                 def muh_method(self) -> muh_type:
            #                     return 'Dumpster fires are all I see.'
            # * This implicitly overrides any previously declared locals of the
            #   same name. Although non-ideal, this constitutes syntactically
            #   valid Python and is thus *NOT* worth emitting even a non-fatal
            #   warning over: e.g.,
            #     # This is fine... technically.
            #     from beartype import beartype
            #     def muh_closure() -> None:
            #         MuhClass = 'This is horrible, yet fine.'
            #
            #         @beartype
            #         class MuhClass(object):
            #             def muh_method(self) -> str:
            #                 return 'Look away and cringe, everyone!'
            if cls_stack:
                # Root and current decorated classes.
                cls_root = cls_stack[0]
                cls_curr = cls_stack[-1]

The cls_stack mentioned above is the stack of all (possibly nested) classes currently being decorated by @beartype. This is what you provide to @beartype when you decorate classes with @beartype. @beartype otherwise has no access to the class stack and thus cannot reliably do anything. It is sad.

Emoji cat cries for Equinox! :crying_cat_face:

Dataclass-generated init methods may soon no longer correspond to their annotations, see this discourse thread...

Gah! Curse you, @dataclasses.dataclass! Of course, @dataclasses.dataclass already violates core typing standards like PEP 484 in various obvious ways: e.g.,

@dataclass
class C:
    # This is a lie. "mylist" is an instance of "dataclasses.Field" rather than "list".
    mylist: list[int] = field(default_factory=list)

Since dataclasses already behaves badly, @beartype kinda just ignores dataclasses as much as it can and hacks around the rest. At this point, more bad behaviour from dataclasses is unsurprising. Sane behaviour would be the surprising thing. Emoji cat continues crying. :crying_cat_face:

...class-level sneakery like the Equinox function-wrappers...

I like the way you think, Dr. Kidger. Yes! Let's do that! Let's do the class-level sneakery thing. Just let me know somewhere how you'd like @beartype to eventually:

We'll make this despicable magic happen yet, boys. :muscle: :bear:

patrick-kidger commented 8 months ago

Ech, that sounds complicated!

Okay, I think we can make this work with just some small tweaks. IIUC, beartype is morally doing something like this:

for key, value in cls.__dict__.items():
    if inspect.isfunction(value):
        setattr(cls, key, beartype(value))

I think it should be enough to change things to:

for key in cls.__dict__.keys():
    value = getattr(cls, key)  # call __get__ to get an actual function, not a function-wrapper
    if inspect.isfunction(value):
        setattr(cls, key, beartype(value))

...that is, once I've merged #587, which adds support for such monkey-patching, by adding a __setattr__ that checks for functions and wraps them into one of Equinox's function-wrappers. :D

leycec commented 8 months ago

Ech, that sounds complicated!

The more you know, the more you know you don't wanna know. @beartype: it's like quantum mechanics that way.

beartype is morally doing something

:rofl:

I think it should be enough to change things to:

This is both clever and obscene. I feel approval.

I also feel trepidation. @beartype does already handle C-based builtin method descriptors wrapping pure-Python unbound methods. This includes @classmethod, @staticmethod, and @property getters, setters, and deleters. I'm pretty sure (but not certain) those descriptors fail to support Equinox's sort of __setattr__()-based monkey-patching. But maybe they do? But probably they don't.

I kinda intuit that your proposed resolution will satisfy the specific use case of Equinox while yet failing the general use case of Python's standard descriptors. Do I really know what I am talking about? The answer may shock you. But probably it won't.
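
One way to see the concern with the standard descriptors, sketched with the standard library only. `checked` is a hypothetical stand-in for a type-checking decorator such as beartype, and `decorate_via_getattr` is an illustrative version of the getattr-based loop proposed above, not beartype's actual code.

```python
import functools
import inspect


def checked(func):
    """Hypothetical stand-in for a type-checking decorator."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


def decorate_via_getattr(cls):
    # Naive loop: resolve every attribute through the descriptor protocol.
    for key in list(vars(cls)):
        value = getattr(cls, key)
        if inspect.isfunction(value):
            setattr(cls, key, checked(value))
    return cls


@decorate_via_getattr
class Tool:
    @staticmethod
    def double(n):
        return 2 * n


# The staticmethod wrapper was replaced by a plain function...
print(type(vars(Tool)["double"]))
# ...so calling through an instance now passes the instance as "n".
try:
    Tool().double(3)
except TypeError as exc:
    print("broken:", exc)
```

`getattr` on a `staticmethod` returns the underlying function, so writing the decorated result back with `setattr` loses the `staticmethod` wrapper. `Tool.double(3)` still works, but the call through an instance does not.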

patrick-kidger commented 8 months ago

Hmm. Good point. Perhaps:

for key, value in cls.__dict__.items():
    if not is_classmethod_or_whatever(value):
        value = getattr(cls, key)
    if inspect.isfunction(value):
        setattr(cls, key, beartype(value))

?

That's obviously kind of a hack that happens to work for the builtins and happens to work for Equinox. But I think the above should hit not just cls.__setattr__ but also cls.key.__set__ (well, technically cls.__dict__[key].__set__), if the latter should happen to be defined, so I think that might be as good as it gets.
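
A runnable sketch of the loop above, under stated assumptions: `checked` is a hypothetical stand-in for beartype, `Wrapped` is a hypothetical third-party method wrapper (not Equinox's real class), and the `isinstance` checks play the role of `is_classmethod_or_whatever`. Property handling is left out.

```python
import functools
import inspect


def checked(func):
    """Hypothetical stand-in for a type-checking decorator such as beartype."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    wrapper.is_checked = True  # marker so the effect is observable
    return wrapper


class Wrapped:
    """Hypothetical third-party method wrapper that unwraps on class access."""

    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner):
        if instance is None:
            return self.func
        return self.func.__get__(instance, owner)


def decorate(cls):
    for key, value in list(vars(cls).items()):
        if isinstance(value, (classmethod, staticmethod)):
            # Builtin wrapper: decorate the inner function, then re-wrap it.
            setattr(cls, key, type(value)(checked(value.__func__)))
            continue
        if isinstance(value, property):
            continue  # property handling omitted from this sketch
        # Anything else: let __get__ unwrap it before testing for a function.
        value = getattr(cls, key)
        if inspect.isfunction(value):
            # Without a re-wrapping __setattr__ on the class (which the thread
            # says #587 adds to Equinox), this stores a plain function.
            setattr(cls, key, checked(value))
    return cls


@decorate
class Demo:
    fn = Wrapped(lambda self, y: y + 1)

    @staticmethod
    def double(n):
        return 2 * n

    @classmethod
    def name(cls):
        return cls.__name__


d = Demo()
print(d.fn(1), d.double(3), d.name())
```

The builtin descriptors keep their wrappers, while the third-party wrapper is resolved through `__get__` and decorated as an ordinary function.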

leycec commented 8 months ago

...heh. Brilliant minds GitHub alike. Coincidentally, that's exactly what I concocted in my bald head while hiking through the frigid wastes of Canada this morning:

  1. Try to do what @beartype currently does for standard builtin decorators.
  2. When that fails, fall back to doing the Equinox-specific thing.

The really nice thing about the Equinox-specific thing is that it should also generalize to arbitrary other third-party packages resembling Equinox. @beartype will then "just work" out-of-the-box without me needing to explicitly support n different package-specific wrapper strategies. Instead, we just support 1 package-specific wrapper strategy. Thus, I tell everybody:

Do what Equinox does. Then @beartype will support you.

This is win. Thankfully, you were even kind enough to promptly merge #587. Therefore, this is @beartype's roadmap to payback:

  1. I add provisional support for this tonight.
  2. I extensively unit test this next week.
  3. Official support for this lands in @beartype 0.17.0 – to be released before Planet Earth is inevitably sucked into a black hole.

Thanks for being so supportive, Dr. Kidger. The bear will howl on the equinox! :bear: :new_moon_with_face:

leycec commented 8 months ago

Possibly resolved by beartype/beartype@58219ba02be8a. In theory, this now works. In practice, nothing is tested. More importantly...

OMFG!!!!!! This was so shockingly insane. It turns out the brute-force approach outlined above fails to suffice. Reality is a harsh mistress and so is the moon. Doing this uncarefully induces INFINITE FRIGGIN' RECURSION on standard types, including:

Needless to say, I remain both shocked and appalled that @beartype has been uselessly attempting to decorate the object and type superclasses all this time. What a clown show of double facepalms! I put an abrupt halt to that tonight. Let us pretend @beartype never did that.

For additional protection against madness, I've also prohibited dunder attributes like __class__ and __base__ from consideration. I'm still concerned that this monkey-patch renders @beartype susceptible to other infinite races I can't foresee. The @beartype test suite passes – but that doesn't necessarily mean much. Tests always pass, especially when the real world is exploding!

Nonetheless, we're shipping this. If @beartype 0.17.0 blows up PyTorch yet again, I can only hang my head and blame @patrick-kidger.

patrick-kidger commented 8 months ago

:D https://twitter.com/PatrickKidger/status/1725954421505622296

For additional protection against madness, I've also prohibited dunder attributes like __class__ and __base__ from consideration. I'm still concerned that this monkey-patch renders @beartype susceptible to other infinite races I can't foresee. The @beartype test suite passes – but that doesn't necessarily mean much. Tests always pass, especially when the real world is exploding!

Maybe go the other way, prohibit all magic methods except those on an explicit allow-list? (And for what it's worth, Equinox's wrappers only apply to non-magic methods.)
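
A sketch of what such an allow-list filter could look like. The contents of `ALLOWED_DUNDERS` and the helper names are hypothetical; the thread does not show the rule beartype actually adopted.

```python
# Hypothetical allow-list; which dunders belong here is a design decision.
ALLOWED_DUNDERS = frozenset({"__init__", "__call__"})


def is_dunder(name):
    """True for names of the form __x__."""
    return len(name) > 4 and name.startswith("__") and name.endswith("__")


def decoratable_names(cls):
    """Names in the class's own __dict__ that pass the allow-list filter."""
    return [
        name for name in vars(cls)
        if not is_dunder(name) or name in ALLOWED_DUNDERS
    ]


class Example:
    def __init__(self, a):
        self.a = a

    def __repr__(self):
        return f"Example({self.a})"

    def fn(self, b):
        return self.a + b


print(decoratable_names(Example))
```

Attributes such as `__dict__`, `__doc__`, and `__repr__` never reach the decoration step, so only names that were explicitly opted in (plus ordinary methods) are candidates.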

leycec commented 8 months ago

https://twitter.com/PatrickKidger/status/1725954421505622296

:rage4: :innocent: :rage4:

Rejoice! Microsoft just hired Sam Altman, restoring balance to the ML Force... wait. What were we talking about again? Is this Reddit? Where even am I? Oh, right. GitHub. It's still happening. </awkward>

So. Gentlemen and scuba divers alike, I have solved all the recursion complaints. @beartype is now robust yet again against pernicious edge cases that are too shameful for me to publicly exhibit here. Although I have yet to explicitly test against Equinox, everything should now work as expected. If anyone with more free time than is healthy would like to disprove these lies I am telling you, please test this for me:

pip install git+https://github.com/beartype/beartype.git@9faf1ecfc0fc3f26ab0de9eab710354476990cb4

To further harden this feature against "Surprise. It's Johnnnny!", I'll also be adding Equinox-specific unit tests to @beartype over the next several days. We inch closer to officially solving everything. :partying_face:

leycec commented 6 months ago

Santa Bear Claws has come to town. In other words, this long-standing issue was resolved by beartype/beartype@9faf1ecfc0fc3 a month and a half ago. This is the first spare moment I've had to genuinely test that commit against Equinox. After all, Xenoblade Chronicles: Definitive Edition doesn't play itself.

Thankfully, all is now full of worky. Behold! :magic_wand:

import equinox as eqx
from beartype import beartype
from jax import numpy as jnp
from jaxtyping import (
    Array,
    Float,
)

@beartype
class MyClass(eqx.Module):
  x: Float[Array, ""]

  def fn(self, y: bool) -> Float[Array, ""]:
    return self.x + 1

MyClass(jnp.array(1.)).fn('not bool')

...which now raises the expected type-checking violation:

Traceback (most recent call last):
  File "/home/leycec/tmp/mopy.py", line 18, in <module>
    MyClass(jnp.array(1.)).fn('not bool')
  File "/home/leycec/py/conda/envs/ionyou_dev/lib/python3.11/site-packages/equinox/_module.py", line 875, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<@beartype(__main__.MyClass.fn) at 0x7fa9bcdfb420>", line 22, in fn
beartype.roar.BeartypeCallHintParamViolation: Method __main__.MyClass.fn() parameter y='not bool' violates
type hint <class 'bool'>, as str 'not bool' not instance of bool.

As the traceback suggests, @beartype now defers to Equinox's expert opinion on the matter. For everyone's safety, I've added an integration test to @beartype's test suite exercising this against regressions. This is the way the AI was won.

@patrick-kidger: Thanks again for all the :cupid: from the Google Open Source Programs Office. This issue can now be safely closed. Happy New Year from the frigid wastelands of Mosquito Land Ontario, Canada! :hugs: