beartype / plum

Multiple dispatch in Python
https://beartype.github.io/plum
MIT License

Bug with forward reference #81

Open PhilipVinc opened 1 year ago

PhilipVinc commented 1 year ago

MWE reduced from #80.

The bug appears to arise because Filter in flax (see here) is defined with a forward reference to DenyList. But DenyList is defined right afterwards in the same module, so shouldn't this work?

# Excerpt from flax/core/scope.py:
Filter = Union[bool, str, typing.Collection[str], 'DenyList']

# When conditioning on filters we require explicit boolean comparisons.
# pylint: disable=g-bool-id-comparison

@dataclasses.dataclass(frozen=True, eq=True)
class DenyList:
  ...

# MWE:
from plum import dispatch
from flax.core.scope import Filter

@dispatch
def f(a: Filter):
    return 2

f(1)

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Input In [10], in <cell line: 1>()
----> 1 f(1)

File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/plum/function.py:344, in Function.__call__(self, *args, **kw_args)
    340 def __call__(self, *args, **kw_args):
    341     # Before attempting to use the cache, resolve any unresolved registrations. Use
    342     # an `if`-statement to speed up the common case.
    343     if self._pending:
--> 344         self._resolve_pending_registrations()
    346     # Attempt to use the cache based on the types of the arguments.
    347     types = tuple(map(type, args))

File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/plum/function.py:222, in Function._resolve_pending_registrations(self)
    220 # Obtain the signature if it is not available.
    221 if signature is None:
--> 222     signature = extract_signature(f, precedence=precedence)
    223 else:
    224     # Ensure that the implementation is `f`, but make a copy before
    225     # mutating.
    226     signature = signature.__copy__()

File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/plum/signature.py:188, in extract_signature(f, precedence)
    185     # Override the `__annotations__` attribute, since `resolve_pep563` modifies
    186     # `f` too.
    187     print(f"exract sign {f}")
--> 188     for k, v in typing.get_type_hints(f).items():
    189         f.__annotations__[k] = v
    191 # Extract specification.

File ~/.pyenv/versions/3.10.6/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/typing.py:1871, in get_type_hints(obj, globalns, localns, include_extras)
   1863 if isinstance(value, str):
   1864     # class-level forward refs were handled above, this must be either
   1865     # a module-level annotation or a function argument annotation
   1866     value = ForwardRef(
   1867         value,
   1868         is_argument=not isinstance(obj, types.ModuleType),
   1869         is_class=False,
   1870     )
-> 1871 value = _eval_type(value, globalns, localns)
   1872 if name in defaults and defaults[name] is None:
   1873     value = Optional[value]

File ~/.pyenv/versions/3.10.6/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/typing.py:329, in _eval_type(t, globalns, localns, recursive_guard)
    327     return t._evaluate(globalns, localns, recursive_guard)
    328 if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)):
--> 329     ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)
    330     if ev_args == t.__args__:
    331         return t

File ~/.pyenv/versions/3.10.6/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/typing.py:329, in <genexpr>(.0)
    327     return t._evaluate(globalns, localns, recursive_guard)
    328 if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)):
--> 329     ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)
    330     if ev_args == t.__args__:
    331         return t

File ~/.pyenv/versions/3.10.6/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/typing.py:327, in _eval_type(t, globalns, localns, recursive_guard)
    321 """Evaluate all forward references in the given type t.
    322 For use of globalns and localns see the docstring for get_type_hints().
    323 recursive_guard is used to prevent infinite recursion with a recursive
    324 ForwardRef.
    325 """
    326 if isinstance(t, ForwardRef):
--> 327     return t._evaluate(globalns, localns, recursive_guard)
    328 if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)):
    329     ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)

File ~/.pyenv/versions/3.10.6/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/typing.py:694, in ForwardRef._evaluate(self, globalns, localns, recursive_guard)
    689 if self.__forward_module__ is not None:
    690     globalns = getattr(
    691         sys.modules.get(self.__forward_module__, None), '__dict__', globalns
    692     )
    693 type_ = _type_check(
--> 694     eval(self.__forward_code__, globalns, localns),
    695     "Forward references must evaluate to types.",
    696     is_argument=self.__forward_is_argument__,
    697     allow_special_forms=self.__forward_is_class__,
    698 )
    699 self.__forward_value__ = _eval_type(
    700     type_, globalns, localns, recursive_guard | {self.__forward_arg__}
    701 )
    702 self.__forward_evaluated__ = True

File <string>:1, in <module>

NameError: name 'DenyList' is not defined

wesselb commented 1 year ago

Hmm, this is not an ideal solution, but explicitly importing DenyList seems to resolve the error:

from plum import dispatch
from flax.core.scope import Filter, DenyList

@dispatch
def f(a: Filter):
    return 2

f(1)

Could you check whether that works correctly within netket?

PhilipVinc commented 1 year ago

Yes, importing DenyList in the session makes it work. Importing DenyList in the netket package itself (rather than in the session) does not.

I imagine that this is because typing is for some reason evaluating the signature in the current scope instead of the scope where the lazy thing is defined...
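The hypothesis can be checked without plum or flax. The sketch below assumes a hypothetical stand-in module, fake_scope, that mimics the relevant part of flax.core.scope (simplified: DenyList is an empty class). By default, typing.get_type_hints(f) evaluates the string "DenyList" in f.__globals__, the namespace of the module where f is defined, not of the module where Filter was written. So it fails until the name is bound there, which is what the explicit import does.

```python
import types
import typing

# Hypothetical stand-in for flax.core.scope, simplified: `Filter` names
# `DenyList` as a string because the class is only defined further down.
LIBRARY_SOURCE = """
import typing
from typing import Union

Filter = Union[bool, str, typing.Collection[str], "DenyList"]


class DenyList:
    pass
"""

library = types.ModuleType("fake_scope")
exec(LIBRARY_SOURCE, library.__dict__)

Filter = library.Filter  # plays the role of `from flax.core.scope import Filter`


def f(a: Filter):
    return 2


# get_type_hints evaluates "DenyList" in f.__globals__ (the namespace f was
# defined in), where no such name exists yet, so it raises NameError.
try:
    typing.get_type_hints(f)
    error_message = ""
except NameError as exc:
    error_message = str(exc)
print(error_message)

# Binding the name in f's module namespace, which is what an explicit
# `from flax.core.scope import DenyList` does, makes the same call succeed.
f.__globals__["DenyList"] = library.DenyList
hints = typing.get_type_hints(f)
print(hints["a"])
```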

PhilipVinc commented 1 year ago

Do you or our bearish friend have any idea what to do? Or is this a bug in flax? I'm not very familiar with these string-based (forward-reference) annotations...
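For reference, a string inside Union[...] is not resolved when the alias is created. typing wraps it in a typing.ForwardRef holding the name, and, as far as I can tell for the CPython versions involved here (3.10 in the traceback), it does not record which module the string came from. Whoever evaluates it later therefore has to supply the right namespace. A quick standard-library check:

```python
import typing
from typing import Collection, Union

# Same shape as the flax alias; "DenyList" is only a string at this point
# and no class of that name exists anywhere in this snippet.
Filter = Union[bool, str, Collection[str], "DenyList"]

members = typing.get_args(Filter)
for member in members:
    print(member)
# The last member is a typing.ForwardRef wrapping the name "DenyList";
# it only becomes a class when something evaluates it later.
```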

wesselb commented 1 year ago

> I imagine that this is because typing is for some reason evaluating the signature in the current scope instead of the scope where the lazy thing is defined...

I think something like this is going on, though I'm not 100% sure.

When I'm home I'll try the netket example to see if that can be made to work!

PhilipVinc commented 1 year ago

As a simple reproducer in netket, you can use:

import netket as nk

hi = nk.hilbert.Spin(0.5)
ma = nk.models.RBM()
vs = nk.vqs.ExactState(hi, ma)
vs.expect_and_grad(nk.operator.spin.sigmam(hi, 0))

The error is triggered by this definition, and possibly by others that use this CollectionFilter annotation (which isn't needed, by the way: it's a keyword argument anyway and was only there to document what type it should be).

Still, this is an annoying bug... I think there should be a way to handle these lazy declarations correctly...
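One possible user-side workaround, sketched under assumptions: resolve the alias against the namespace of the module that defines it, once, before using it as an annotation. resolve_alias below is a hypothetical helper, not part of plum, flax, or typing; it relies only on the globalns argument of typing.get_type_hints. fake_scope is again a hypothetical, simplified stand-in for flax.core.scope.

```python
import types
import typing

# Hypothetical stand-in for flax.core.scope (simplified).
LIBRARY_SOURCE = """
import typing
from typing import Union

Filter = Union[bool, str, typing.Collection[str], "DenyList"]


class DenyList:
    pass
"""
library = types.ModuleType("fake_scope")
exec(LIBRARY_SOURCE, library.__dict__)


def resolve_alias(alias, module):
    """Hypothetical helper: evaluate the forward references inside `alias`
    using the global namespace of `module`, where the alias was written."""

    def probe(x: alias):  # dummy function that only carries the annotation
        pass

    return typing.get_type_hints(probe, globalns=vars(module))["x"]


ResolvedFilter = resolve_alias(library.Filter, library)


def f(a: ResolvedFilter):  # no forward references are left to evaluate
    return 2


hints = typing.get_type_hints(f)
print(hints["a"])
```

With the real library this would presumably be resolve_alias(flax.core.scope.Filter, flax.core.scope) after import flax.core.scope; I have not checked that against flax or plum. Resolving eagerly in the library's own namespace means later consumers, plum included, only ever see concrete classes.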

wesselb commented 12 months ago

@PhilipVinc Are this issue and #80 still current, or have you managed to resolve the problems since?

PhilipVinc commented 12 months ago

I fixed it in netket by importing the definition of that type, but I don't think that's a good solution, so this is still a bug.

wesselb commented 12 months ago

Glad to hear that you've fixed it by importing the definition. I agree that this should just work, but I'm also not sure how that could be done. Let's leave the issue open then.