beartype / plum

Multiple dispatch in Python
https://beartype.github.io/plum
MIT License

Plum 2 and np.typing.NDArray #74

Open francesco-ballarin opened 1 year ago

francesco-ballarin commented 1 year ago

Hi @wesselb, I am testing out https://github.com/wesselb/plum/pull/73 because of its improved NumPy support. In the very first example there, you show how to dispatch based on both shape and dtype. In my use cases, I only need to dispatch based on dtype. Since np.typing.NDArray already supports that, I thought I could just use it, but that doesn't seem to work.
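For context, dispatching on dtype alone does not strictly require type hints: a lookup keyed on `x.dtype` is enough. The sketch below is a hand-rolled stopgap, not plum's mechanism; the class `DtypeDispatcher` and its `register` method are hypothetical names introduced only for illustration.

```python
from typing import Callable, Dict

import numpy as np


class DtypeDispatcher:
    """Hypothetical helper, NOT part of plum: pick an implementation
    based only on the dtype of the array argument."""

    def __init__(self) -> None:
        self._impls: Dict[np.dtype, Callable] = {}

    def register(self, scalar_type) -> Callable[[Callable], Callable]:
        """Decorator that registers `func` for arrays of `scalar_type`."""
        def decorator(func: Callable) -> Callable:
            self._impls[np.dtype(scalar_type)] = func
            return func
        return decorator

    def __call__(self, x: np.ndarray):
        # Exact dtype match only: no promotion and no subtype ordering.
        impl = self._impls.get(x.dtype)
        if impl is None:
            raise TypeError(f"no implementation registered for dtype {x.dtype}")
        return impl(x)


g = DtypeDispatcher()


@g.register(np.int32)
def _(x: np.ndarray) -> str:
    return "An int array!"


@g.register(np.float64)
def _(x: np.ndarray) -> str:
    return "A float array!"


print(g(np.ones((3, 3), np.int32)))    # An int array!
print(g(np.ones((2, 2), np.float64)))  # A float array!
```

Unlike plum, this ignores shape and class hierarchy and only handles a single argument, so it is a workaround rather than a replacement for multiple dispatch.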

Below is a minimal example, in case you want to have a look before the Plum 2 release.

from plum import dispatch, parametric
from typing import Any, Optional, Tuple, Union

import numpy as np
import numpy.typing

class NDArrayMeta(type):
    def __instancecheck__(self, x):
        if self.concrete:
            shape, dtype = self.type_parameter
        else:
            shape, dtype = None, None
        return (
            isinstance(x, np.ndarray)
            and (shape is None or x.shape == shape)
            and (dtype is None or x.dtype == dtype)
        )

@parametric
class NDArray(np.ndarray, metaclass=NDArrayMeta):
    @classmethod
    @dispatch
    def __init_type_parameter__(
        cls,
        shape: Optional[Tuple[int, ...]],
        dtype: Optional[Any],
    ):
        """Validate the type parameter."""
        return shape, dtype

    @classmethod
    @dispatch
    def __le_type_parameter__(
        cls,
        left: Tuple[Optional[Tuple[int, ...]], Optional[Any]],
        right: Tuple[Optional[Tuple[int, ...]], Optional[Any]],
    ):
        """Define an order on type parameters. That is, check whether
        `left <= right` or not."""
        shape_left, dtype_left = left
        shape_right, dtype_right = right
        return (
            (shape_right is None or shape_left == shape_right)
            and (dtype_right is None or dtype_left == dtype_right)
        )

@dispatch
def f(x: NDArray[None, np.int32]):
    print("An int array!")

@dispatch
def f(x: NDArray[None, np.float64]):
    print("A float array!")

print("BEGIN f")
f(np.ones((3, 3), np.int32))
f(np.ones((2, 2), np.float64))
print("END f")

@dispatch
def g(x: np.typing.NDArray[np.int32]):
    print("An int array!")

@dispatch
def g(x: np.typing.NDArray[np.float64]):
    print("A float array!")

print("BEGIN g")
g(np.ones((3, 3), np.int32))
g(np.ones((2, 2), np.float64))
print("END g")

And here is the corresponding output:

BEGIN f
An int array!
A float array!
END f
BEGIN g
/usr/local/lib/python3.11/dist-packages/plum/signature.py:203: UserWarning: Could not resolve the type hint of `numpy.ndarray[typing.Any, numpy.dtype[numpy.int32]]`. I have ended the resolution here to not make your code break, but some types might not be working correctly. Please open an issue at https://github.com/wesselb/plum.
  annotation = resolve_type_hint(p.annotation)
/usr/local/lib/python3.11/dist-packages/plum/type.py:261: UserWarning: Could not resolve the type hint of `numpy.ndarray[typing.Any, numpy.dtype[numpy.int32]]`. I have ended the resolution here to not make your code break, but some types might not be working correctly. Please open an issue at https://github.com/wesselb/plum.
  return _is_faithful(resolve_type_hint(x))
/usr/local/lib/python3.11/dist-packages/plum/type.py:261: UserWarning: Could not determine whether `numpy.ndarray[typing.Any, numpy.dtype[numpy.int32]]` is faithful or not. I have concluded that the type is not faithful, so your code might run with subpar performance. Please open an issue at https://github.com/wesselb/plum.
  return _is_faithful(resolve_type_hint(x))
/usr/local/lib/python3.11/dist-packages/plum/signature.py:203: UserWarning: Could not resolve the type hint of `numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]`. I have ended the resolution here to not make your code break, but some types might not be working correctly. Please open an issue at https://github.com/wesselb/plum.
  annotation = resolve_type_hint(p.annotation)
/usr/local/lib/python3.11/dist-packages/plum/type.py:261: UserWarning: Could not resolve the type hint of `numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]`. I have ended the resolution here to not make your code break, but some types might not be working correctly. Please open an issue at https://github.com/wesselb/plum.
  return _is_faithful(resolve_type_hint(x))
/usr/local/lib/python3.11/dist-packages/plum/type.py:261: UserWarning: Could not determine whether `numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]` is faithful or not. I have concluded that the type is not faithful, so your code might run with subpar performance. Please open an issue at https://github.com/wesselb/plum.
  return _is_faithful(resolve_type_hint(x))
Traceback (most recent call last):
  File "/tmp/p.py", line 72, in <module>
    g(np.ones((3, 3), np.int32))
  File "/usr/local/lib/python3.11/dist-packages/plum/function.py", line 342, in __call__
    self._resolve_pending_registrations()
  File "/usr/local/lib/python3.11/dist-packages/plum/function.py", line 237, in _resolve_pending_registrations
    self._resolver.register(subsignature)
  File "/usr/local/lib/python3.11/dist-packages/plum/resolver.py", line 58, in register
    existing = [s == signature for s in self.signatures]
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/plum/resolver.py", line 58, in <listcomp>
    existing = [s == signature for s in self.signatures]
                ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/plum/util.py", line 132, in __eq__
    return self <= other <= self
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/plum/signature.py", line 132, in __le__
    [TypeHint(x) <= TypeHint(y) for x, y in zip(self_types, other_types)]
  File "/usr/local/lib/python3.11/dist-packages/plum/signature.py", line 132, in <listcomp>
    [TypeHint(x) <= TypeHint(y) for x, y in zip(self_types, other_types)]
     ^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/beartype/door/_doormeta.py", line 148, in __call__
    _HINT_KEY_TO_WRAPPER.cache_or_get_cached_func_return_passed_arg(
  File "/usr/local/lib/python3.11/dist-packages/beartype/_util/cache/map/utilmapbig.py", line 231, in cache_or_get_cached_func_return_passed_arg
    value = value_factory(arg)
            ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/beartype/door/_doormeta.py", line 220, in _make_wrapper
    wrapper_subclass = get_typehint_subclass(hint)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/beartype/door/_doordata.py", line 108, in get_typehint_subclass
    raise BeartypeDoorNonpepException(
beartype.roar.BeartypeDoorNonpepException: Type hint numpy.ndarray[typing.Any, numpy.dtype[numpy.int32]] invalid (i.e., either PEP-noncompliant or PEP-compliant but currently unsupported by "beartype.door.TypeHint").

Thanks!

wesselb commented 1 year ago

Hey @francesco-ballarin,

Thanks for opening an issue about this. :) You're right that numpy.typing currently doesn't work. The problem is that numpy.typing types unfortunately aren't functional on their own:

>>> isinstance(1, npt.NDArray[int])
TypeError: isinstance() argument 2 cannot be a parameterized generic
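The same restriction applies to any parameterized generic, not only those from numpy.typing: `isinstance` refuses them, so a library has to take the hint apart itself with `typing.get_origin` and `typing.get_args`. A minimal stdlib sketch of that idea, limited to `Any`, plain classes, and `list[T]` for brevity (the helper `check` is hypothetical and is not how plum or beartype are implemented):

```python
import typing
from typing import Any


def check(x: Any, hint: Any) -> bool:
    """Toy runtime check supporting Any, plain classes and list[T] only."""
    if hint is Any:
        return True
    origin = typing.get_origin(hint)
    if origin is None:
        # A plain, unparameterized class: isinstance() works directly.
        return isinstance(x, hint)
    if origin is list:
        # list[T]: check the container, then every element against T.
        (item_hint,) = typing.get_args(hint)
        return isinstance(x, list) and all(check(item, item_hint) for item in x)
    raise NotImplementedError(f"unsupported hint: {hint!r}")


try:
    isinstance([1, 2], list[int])
except TypeError as err:
    print(err)  # isinstance() argument 2 cannot be a parameterized generic

print(check([1, 2], list[int]))    # True
print(check([1, "a"], list[int]))  # False
```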

They are type hints, like the objects in typing, which don't work on their own but need additional support. @beartype seems to come very close, but unfortunately doesn't quite get it right:

>>> from beartype.door import is_bearable, TypeHint

>>> is_bearable(np.ones(1, int), npt.NDArray[int])   # Nice!
True

>>> TypeHint(npt.NDArray[int])   # Nice!
TypeHint(numpy.ndarray[typing.Any, numpy.dtype[int]])

>>> TypeHint(npt.NDArray[int]) == TypeHint(npt.NDArray[float])   # :(
True
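What a dtype-aware comparison would need to do can be sketched by unpacking the hint, which (as the warnings in the traceback show) expands to `numpy.ndarray[<shape>, numpy.dtype[T]]`. The helper names `hint_dtype`, `matches`, `hint_le`, and `hint_eq` are hypothetical; the ordering mirrors `__le_type_parameter__` from the example above, and equality is defined as `left <= right <= left`, as in plum's `__eq__` in the traceback. This is a sketch, not how beartype or plum actually implement it.

```python
import typing
from typing import Any, Optional

import numpy as np
import numpy.typing as npt


def hint_dtype(hint: Any) -> Optional[np.dtype]:
    """Return the dtype pinned down by a hint like npt.NDArray[T],
    or None if the hint accepts any dtype (e.g. npt.NDArray[Any])."""
    if typing.get_origin(hint) is not np.ndarray:
        raise TypeError(f"not an ndarray hint: {hint!r}")
    # npt.NDArray[T] expands to np.ndarray[<shape>, np.dtype[T]].
    _shape, dtype_hint = typing.get_args(hint)
    (scalar,) = typing.get_args(dtype_hint)
    return np.dtype(scalar) if isinstance(scalar, type) else None


def matches(x: Any, hint: Any) -> bool:
    """isinstance()-style check for npt.NDArray[T] hints (dtype only)."""
    dtype = hint_dtype(hint)
    return isinstance(x, np.ndarray) and (dtype is None or x.dtype == dtype)


def hint_le(left: Any, right: Any) -> bool:
    """True if every array matching `left` also matches `right`."""
    dtype_left, dtype_right = hint_dtype(left), hint_dtype(right)
    # Guard against None explicitly rather than comparing it to a dtype.
    return dtype_right is None or (
        dtype_left is not None and dtype_left == dtype_right
    )


def hint_eq(left: Any, right: Any) -> bool:
    """Equality as left <= right <= left."""
    return hint_le(left, right) and hint_le(right, left)


print(hint_eq(npt.NDArray[np.int32], npt.NDArray[np.float64]))  # False
print(matches(np.ones(3, np.int32), npt.NDArray[np.int32]))      # True
```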

You might be getting a beartype.roar.BeartypeDoorNonpepException, whereas I'm not, because of a different Python version. (I ran the above using Python 3.9; your traceback shows Python 3.11.) EDIT: This seems to be the case; see the issue linked below.

I've opened an issue on @beartype to see what @leycec thinks about this.