patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Unexpected results with __subclasshook__ #893

Status: Closed (SimonKoop closed this issue 1 week ago)

SimonKoop commented 1 week ago

When using a __subclasshook__ with Equinox modules, issubclass can return True for pairs of classes where it would normally (i.e. if they weren't Equinox modules) return False. For example, in the following code, the print statement at the end prints the same boolean that with_eqx is set to.

with_eqx: bool = False  # set to True to build the hierarchy on Equinox's base classes

if with_eqx:
    import equinox as eqx

    Module = eqx.Module
    StatefulModule = eqx.nn.StatefulLayer
else:
    class Module:
        pass 

    class StatefulModule(Module):
        pass

class RegularLayer(Module):
    pass

class MaybeStatefulLayer(StatefulModule):
    pass

class SpecificRegularLayer(RegularLayer):
    pass

class SpecialModuleCollection(Module):
    @classmethod
    def __subclasshook__(cls, maybe_subclass):
        if issubclass(maybe_subclass, (RegularLayer, MaybeStatefulLayer)):
            return True
        return NotImplemented

class SpecificSpecialModule(StatefulModule, SpecialModuleCollection):
    pass

print(issubclass(SpecificRegularLayer, StatefulModule))  # False if with_eqx is set to False, otherwise True
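For anyone trying to pin down where the True comes from, here is a hedged, Equinox-free sketch. It assumes, as the maintainer's reproducer below suggests, that Equinox's module metaclass derives from abc.ABCMeta, so a plain ABCMeta base stands in for eqx.Module. The function name check and its guard flag are hypothetical, introduced only for illustration. The hook defined on SpecialModuleCollection is inherited by SpecificSpecialModule, and ABCMeta's subclass check also consults every class in StatefulModule.__subclasses__(), so issubclass(SpecificRegularLayer, StatefulModule) gets answered via SpecificSpecialModule. The guard (only answering when cls is SpecialModuleCollection, the same pattern used in the __subclasshook__ example in the abc documentation) is one possible workaround; it was not proposed in this thread.

```python
import abc


def check(guard: bool) -> tuple[bool, bool, bool]:
    """Rebuild the issue's hierarchy on a plain abc.ABCMeta base and return
    issubclass(SpecificRegularLayer, X) for X in
    (SpecialModuleCollection, SpecificSpecialModule, StatefulModule)."""

    # Assumption: abc.ABCMeta stands in for Equinox's module metaclass.
    class Module(metaclass=abc.ABCMeta):
        pass

    class StatefulModule(Module):
        pass

    class RegularLayer(Module):
        pass

    class MaybeStatefulLayer(StatefulModule):
        pass

    class SpecificRegularLayer(RegularLayer):
        pass

    class SpecialModuleCollection(Module):
        @classmethod
        def __subclasshook__(cls, maybe_subclass):
            # With guard=True the hook only answers for SpecialModuleCollection
            # itself; subclasses that inherit it defer to the normal checks.
            if guard and cls is not SpecialModuleCollection:
                return NotImplemented
            if issubclass(maybe_subclass, (RegularLayer, MaybeStatefulLayer)):
                return True
            return NotImplemented

    class SpecificSpecialModule(StatefulModule, SpecialModuleCollection):
        pass

    # The StatefulModule check falls back to asking each class in
    # StatefulModule.__subclasses__(), which includes SpecificSpecialModule,
    # so the third result follows the second.
    return (
        issubclass(SpecificRegularLayer, SpecialModuleCollection),
        issubclass(SpecificRegularLayer, SpecificSpecialModule),
        issubclass(SpecificRegularLayer, StatefulModule),
    )


print(check(guard=False))  # (True, True, True)
print(check(guard=True))   # (True, False, False)
```

With the guard, SpecificRegularLayer is still treated as a virtual subclass of SpecialModuleCollection (the intended effect of the hook), but no longer of SpecificSpecialModule or StatefulModule.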
patrick-kidger commented 1 week ago

Oh wow, now that really is super weird.

This actually doesn't look like a bug in Equinox, but a bug in Python's abc module itself. Here's a reproducer that doesn't use Equinox at all:

import abc

class X(metaclass=abc.ABCMeta):
    pass

class Y(X):
    pass

class Z(X):
    pass

define_magic = True  # try toggling this value
if define_magic:
    class Magic(Y):
        @classmethod
        def __subclasshook__(cls, maybe_subclass):
            if issubclass(maybe_subclass, Z):
                return True
            return NotImplemented

assert issubclass(Z, Y) == define_magic

(FWIW, I'm using Python 3.11.9.)

I'd suggest checking whether this is still present in recent Python releases, and if so, opening a bug against Python itself.
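The stdlib-only sketch below is an illustration, not part of the original comment; it prints each step for the reproducer above with define_magic = True. As far as I can tell from CPython's abc implementation, ABCMeta.__subclasscheck__ tries, in order, the class's __subclasshook__, its MRO, its registered virtual subclasses, and finally issubclass against each entry of cls.__subclasses__(). That last step is where Magic's hook gets to answer on behalf of Y.

```python
import abc


class X(metaclass=abc.ABCMeta):
    pass


class Y(X):
    pass


class Z(X):
    pass


class Magic(Y):
    @classmethod
    def __subclasshook__(cls, maybe_subclass):
        if issubclass(maybe_subclass, Z):
            return True
        return NotImplemented


# Z does not inherit from Y...
print(Y in Z.__mro__)  # False
# ...but Magic is a direct subclass of Y, so Y's check will consult it...
print(Magic in Y.__subclasses__())  # True
# ...and Magic's hook claims Z, so the check for Y comes back True via Magic.
print(issubclass(Z, Magic))  # True
print(issubclass(Z, Y))  # True
```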

SimonKoop commented 1 week ago

Thank you for your reply! Indeed, the bug still seems to exist in Python 3.13, so I'll file a bug report with Python.