coady / multimethod

Multiple argument dispatching.
https://coady.github.io/multimethod

Bugfix/multidispatch on kwargs #89

Closed svaningelgem closed 12 months ago

svaningelgem commented 1 year ago

Hi @coady,

I ran into an issue at work where keyword arguments were not matched properly, so I dug into the engine that matches the full signature. Since multimethod only dispatches on positional arguments, you suggested I use multidispatch, but that one also fell short of what I needed.

So I looked at other libraries, but there doesn't seem to be much available for this kind of routing. I found one called overload that did nearly what I needed, so I took its tests and changed the code until they passed against the new multidispatch implementation.

BTW, the only class I touched is multidispatch (and I moved some multimethod tests to the test_method file, as they belong there IMHO).

I'm raising this PR to get your thoughts on what I did and some pointers on how to improve it.

One thing I'm not very certain about is what I also mentioned in the README:

from multimethod import multidispatch

@multidispatch
def func(a, b):
    return 1

@func.register
def _(a, b: float = 1.0, c: str = "A"):
    return 2

print(func(1, 2))  # 1
print(func(1, "A"))  # 1

The last call behaves this way because "A" is bound to b, which should be a float. That's because I'm using the Signature.bind method, which doesn't take the type hints into account...

Implementing a full-blown matching engine might go too far and would slow things down considerably, so I'm not sure it's truly worth the effort right now.
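To illustrate the Signature.bind point: bind only checks arity and parameter names, so it succeeds even when a value contradicts the annotation. The sketch below (standard library only; the helper names binds and binds_and_type_checks are hypothetical) binds first and then checks each bound value with isinstance, assuming the annotations are plain classes rather than typing constructs.

```python
import inspect


def second_variant(a, b: float = 1.0, c: str = "A"):
    return 2


def binds(func, *args, **kwargs):
    """True if the arguments fit the signature's arity and names (types ignored)."""
    try:
        inspect.signature(func).bind(*args, **kwargs)
    except TypeError:
        return False
    return True


def binds_and_type_checks(func, *args, **kwargs):
    """Bind first, then check each bound value against its plain-class annotation."""
    sig = inspect.signature(func)
    try:
        bound = sig.bind(*args, **kwargs)
    except TypeError:
        return False
    for name, value in bound.arguments.items():
        annotation = sig.parameters[name].annotation
        if annotation is not inspect.Parameter.empty and not isinstance(value, annotation):
            return False
    return True


print(binds(second_variant, 1, "A"))                  # True: bind ignores that b should be a float
print(binds_and_type_checks(second_variant, 1, "A"))  # False: "A" is not a float
print(binds_and_type_checks(second_variant, 1, 2.0))  # True
```

The extra check is one isinstance call per bound argument, which is cheap compared to a full matching engine, but it does not understand Union, Optional, or other typing constructs.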

At least the basics work: matching ANY signature with ANY kind of arguments.

Let's discuss :)

coady commented 1 year ago

I think overload, which is already built in, is what you're looking for. It does a linear scan of the function signatures.

from multimethod import overload

@overload
def func(a, b):
    return 1

@func.register
def _(a, b: float = 1.0, c: str = "A"):
    return 2
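For reference, the "linear scan" idea can be sketched roughly as follows. This is a hypothetical, simplified class (linear_overload, with made-up predicates is_int and is_str), not multimethod's actual overload: it tries each registered function in registration order, binds the arguments, and treats each annotation as a predicate that is called with the bound value.

```python
import inspect


class linear_overload:
    """Hypothetical linear-scan dispatcher; a simplified sketch, not multimethod's code.

    Annotations are treated as predicates: callables returning True when the
    bound value is acceptable. Unannotated parameters accept anything.
    """

    def __init__(self, func):
        self.funcs = [func]  # kept in registration order

    def register(self, func):
        self.funcs.append(func)
        return self

    def __call__(self, *args, **kwargs):
        for func in self.funcs:  # scan oldest-first: the first match wins
            sig = inspect.signature(func)
            try:
                bound = sig.bind(*args, **kwargs)
            except TypeError:
                continue  # wrong arity or unknown keyword: try the next one
            if all(
                sig.parameters[name].annotation is inspect.Parameter.empty
                or sig.parameters[name].annotation(value)
                for name, value in bound.arguments.items()
            ):
                return func(*args, **kwargs)
        raise TypeError("no matching function")


def is_int(value):
    return isinstance(value, int)


def is_str(value):
    return isinstance(value, str)


@linear_overload
def describe(x: is_int):
    return "int"


@describe.register
def _(x: is_str):
    return "str"


@describe.register
def _(x, y):
    return "pair"
```

Because keywords go through bind as well, describe(x=1) dispatches the same way as describe(1). Note that in this predicate style a typing.Union annotation would be called directly, which typing does not allow.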

svaningelgem commented 1 year ago

Yes and no...

I tried this by changing all my tests in the test_dispatch.py file to use overload:

from typing import Union

import pytest
from multimethod import DispatchError, overload


def test_keywords():
    class cls: pass

    @overload
    def func(arg):
        return 0

    @overload
    def func(arg: int):
        return 1

    @overload
    def func(arg: int, extra: Union[int, float]):
        return 2

    @overload
    def func(arg: int, extra: str):
        return 3

    @overload
    def func(arg: int, *, extra: cls):
        return 4

    assert func("sth") == 0
    assert func(0) == func(arg=0) == 1
    assert func(0, 0.0) == func(0, extra=0.0) == func(arg=0, extra=0.0) == 2
    assert func(0, 0) == func(0, extra=0) == func(arg=0, extra=0) == 2
    assert func(0, '') == func(0, extra='') == func(arg=0, extra='') == 3
    assert func(0, extra=cls()) == func(arg=0, extra=cls()) == 4

    with pytest.raises(DispatchError):
        func(0, cls())

Fails with

self = typing.Union, args = (0.0,), kwds = {}

    def __call__(self, *args, **kwds):
>       raise TypeError(f"Cannot instantiate {self!r}")
E       TypeError: Cannot instantiate typing.Union
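The traceback suggests the Union annotation is being called as if it were a predicate, and typing.Union cannot be called. A minimal sketch of one workaround, assuming annotations are either plain classes or Union/Optional of plain classes: normalize each annotation into an isinstance-based predicate using typing.get_origin and typing.get_args (standard library, Python 3.8+). The helper name to_predicate is hypothetical.

```python
import typing
from typing import Optional, Union


def to_predicate(annotation):
    """Turn a class, or a Union of classes, into an isinstance-based predicate."""
    if typing.get_origin(annotation) is Union:
        # Optional[X] is Union[X, None], so get_args returns (X, NoneType).
        types = typing.get_args(annotation)
    elif isinstance(annotation, type):
        types = (annotation,)
    else:
        # Assume it is already a predicate callable. Note: PEP 604 unions
        # such as `int | float` (types.UnionType) are not handled here.
        return annotation
    return lambda value: isinstance(value, types)


number = to_predicate(Union[int, float])
maybe_str = to_predicate(Optional[str])
```

With this normalization done once at registration time, the dispatch loop only ever calls ordinary predicates, and Optional comes along for free because it is just a Union with NoneType.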

It seems to largely work, except when varargs come into the picture.

For example:

def test_var_positional():
    """Check that we can overload instance methods with variable positional arguments."""

    class cls:
        @overload
        def func(self):
            return 1

        @overload
        def func(self, *args: object):
            return 2

    assert cls().func() == 1
    assert cls().func(1) == 2

This always dispatches to the varargs version.

Or when defaults come into play:

def test_different_signatures():
    @overload
    def func(a: int):
        return f'int: {a}'

    @overload
    def func(a: int, b: float = 3.0):
        return f'int_float: {a} / {b}'

    assert func(1) == 'int: 1'

Fails with:

Expected :'int: 1'
Actual   :'int_float: 1 / 3.0'

It largely does what I want, but not entirely. In the last example, for instance, it seems to pick the last registered function, or the one with varargs, i.e. the one that can consume the most arguments, whereas logically I would expect it to use the first one that can match the signature.

svaningelgem commented 1 year ago

Ok, changing this:

        for sig in reversed(self):
==>
        for sig in self:

already solved a few of the test cases.
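Why the iteration order matters for the defaults case: both func(a: int) and func(a: int, b: float = 3.0) accept func(1), so whichever signature the scan reaches first wins. A small standard-library sketch (the pick helper and candidates list are hypothetical, and type checking is left out for brevity) compares scanning newest-first, as with reversed(self), against oldest-first:

```python
import inspect


def int_only(a: int):
    return f"int: {a}"


def int_float(a: int, b: float = 3.0):
    return f"int_float: {a} / {b}"


candidates = [int_only, int_float]  # registration order


def pick(funcs, *args, **kwargs):
    """Call the first function whose signature binds the arguments."""
    for func in funcs:
        try:
            inspect.signature(func).bind(*args, **kwargs)
        except TypeError:
            continue
        return func(*args, **kwargs)
    raise TypeError("no matching function")


newest_first = pick(list(reversed(candidates)), 1)  # like `for sig in reversed(self)`
oldest_first = pick(candidates, 1)                  # like `for sig in self`
print(newest_first)  # int_float: 1 / 3.0
print(oldest_first)  # int: 1
```

Oldest-first gives the "first registered match wins" rule argued for above. The trade-off is that a catch-all registered early will shadow everything registered after it.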

svaningelgem commented 1 year ago

    def __init__(self, func: Callable):
        for name, value in get_type_hints(func).items():
            if getattr(value, '__origin__', None) is Union:
                func.__annotations__[name] = isa(value.__args__)
            elif not callable(value) or isinstance(value, type):
                func.__annotations__[name] = isa(value)

        self[inspect.signature(func)] = func

    def _check(self, param, value, sub: bool = False):
        if not sub:
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                return all(self._check(param, v, sub=True) for v in value)

            if param.kind == inspect.Parameter.VAR_KEYWORD:
                return all(self._check(param, v, sub=True) for v in value.values())

        return param.annotation is param.empty or param.annotation(value)

    def __call__(self, *args, **kwargs):
        """Dispatch to first matching function."""
        for sig in self:  # first match wins (see the change above)
            try:
                arguments = sig.bind(*args, **kwargs).arguments
            except TypeError:
                continue
            if all(
                self._check(param, arguments[name])
                for name, param in sig.parameters.items()
                if name in arguments
            ):
                return self[sig](*args, **kwargs)
        raise DispatchError("No matching functions found")

    @tp_overload
    def register(self, *args: type) -> Callable: ...

    @tp_overload
    def register(self, func: Callable) -> Callable: ...

    def register(self, *args) -> Callable:
        """Decorator for registering a function."""
        if len(args) == 1 and hasattr(args[0], '__annotations__'):
            func = args[0]
            self.__init__(func)
            return self if self.__name__ == func.__name__ else func  # type: ignore
        return lambda func: self.__setitem__(args, func) or func

This solves:

  1. registration with varargs
  2. Union in type hints
  3. taking the first matching method instead of the last
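The VAR_POSITIONAL and VAR_KEYWORD branches in _check above apply the annotation to every element, because bind collects extra positional arguments into a tuple and extra keyword arguments into a dict. A standalone sketch of the same idea, assuming plain-class annotations checked with isinstance (the check_bound helper and the method function are hypothetical):

```python
import inspect


def check_bound(func, *args, **kwargs):
    """Bind the call, then check every value, element-wise for *args and **kwargs."""
    sig = inspect.signature(func)
    try:
        bound = sig.bind(*args, **kwargs)
    except TypeError:
        return False
    for name, value in bound.arguments.items():
        param = sig.parameters[name]
        if param.annotation is inspect.Parameter.empty:
            continue
        if param.kind == inspect.Parameter.VAR_POSITIONAL:
            values = value  # tuple of extra positional arguments
        elif param.kind == inspect.Parameter.VAR_KEYWORD:
            values = value.values()  # dict of extra keyword arguments
        else:
            values = (value,)
        if not all(isinstance(v, param.annotation) for v in values):
            return False
    return True


def method(self, *args: int, **options: str):
    return 2
```

Here an annotation like *args: int means "every extra positional argument is an int", which matches how type checkers read varargs annotations.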

However, it does not solve this:

    @overload
    def roshambo(left, right):
        return 'tie'

    @roshambo.register(scissors, rock)
    @roshambo.register(rock, scissors)
    def roshambo(left, right):
        return 'rock smashes scissors'

Every call returns 'tie' there.

svaningelgem commented 1 year ago

Another thing overload doesn't support: pending (forward-referenced) types:

def test_unknown_types():
    class A: pass

    @overload
    def func(a: int):
        return 1

    @func.register
    def func(a: "A"):
        return 2

    assert func(1) == 1
    assert func(A()) == 2

==> E NameError: name 'A' is not defined
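One way pending types could work, sketched under assumptions: the NameError comes from resolving the string annotation "A" at registration time, before A is available. The hypothetical lazy_overload class below postpones typing.get_type_hints until a function is first considered during a call. It resolves names only against the function's module globals, so it handles a class defined later at module level. A class local to a test function, as in the snippet above, would additionally need a local namespace passed through get_type_hints' localns argument.

```python
import inspect
import typing


class lazy_overload:
    """Hypothetical first-match dispatcher that resolves string annotations lazily."""

    def __init__(self):
        self.funcs = []     # registered functions, in registration order
        self.resolved = {}  # func -> resolved type hints, filled on first use

    def register(self, func):
        # Do not call get_type_hints here: "A" may not be defined yet.
        self.funcs.append(func)
        return func

    def _hints(self, func):
        if func not in self.resolved:
            # Resolve forward references now, at call time, and cache them.
            self.resolved[func] = typing.get_type_hints(func)
        return self.resolved[func]

    def __call__(self, *args, **kwargs):
        for func in self.funcs:
            try:
                bound = inspect.signature(func).bind(*args, **kwargs)
            except TypeError:
                continue
            hints = self._hints(func)
            if all(isinstance(value, hints[name])
                   for name, value in bound.arguments.items() if name in hints):
                return func(*args, **kwargs)
        raise TypeError("no matching function")


dispatch = lazy_overload()


@dispatch.register
def handle_int(a: int):
    return 1


@dispatch.register
def handle_a(a: "A"):
    return 2


class A:  # defined only after the overload that references it
    pass
```

Caching the resolved hints keeps the cost of get_type_hints to once per function; an unresolvable name still raises NameError, just at the first call that reaches that function instead of at registration.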

ipcoder commented 1 year ago

Hi @svaningelgem, it seems you have added support for several important cases. Is there a particular reason why Optional arguments are out of scope?

@overload
def f(x: int, c: str = None): ...

coady commented 1 year ago

As of v1.10, dispatch supports optionals.

coady commented 12 months ago

Closing since this seems stalled. Maybe it can be split into separate ideas.