svaningelgem closed this 12 months ago.
I think `overload`, which is already built in, is what you're looking for. It does a linear scan of the function signatures.
```python
from multimethod import overload


@overload
def func(a, b):
    return 1


@func.register
def _(a, b: float = 1.0, c: str = "A"):
    return 2
```
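To make the "linear scan" concrete, here is a minimal sketch of the idea. It is not `multimethod`'s actual implementation: the names `LinearDispatch`, `is_int`, `by_int` and `fallback` are hypothetical, and it assumes every annotation is a predicate (a callable returning a bool). Registered signatures are tried in registration order and the first one that binds and passes its predicates wins.

```python
import inspect
from typing import Any, Callable, Dict


class LinearDispatch:
    """Hypothetical dispatcher: scan registered signatures in order, first match wins."""

    def __init__(self) -> None:
        # dicts preserve insertion order, so iteration follows registration order
        self.registry: Dict[inspect.Signature, Callable] = {}

    def register(self, func: Callable) -> Callable:
        self.registry[inspect.signature(func)] = func
        return func

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        for sig, func in self.registry.items():
            try:
                bound = sig.bind(*args, **kwargs)
            except TypeError:
                continue  # wrong arity or unknown keyword: try the next signature
            # Each annotation is called as a predicate; unannotated parameters match anything.
            if all(
                sig.parameters[name].annotation is inspect.Parameter.empty
                or sig.parameters[name].annotation(value)
                for name, value in bound.arguments.items()
            ):
                return func(*args, **kwargs)
        raise TypeError("no matching function found")


def is_int(value: Any) -> bool:
    return isinstance(value, int)


dispatch = LinearDispatch()


@dispatch.register
def by_int(x: is_int):
    return "int"


@dispatch.register
def fallback(x):
    return "anything"
```

With this scheme `dispatch(1)` returns `"int"` and `dispatch("s")` falls through to `"anything"`. Because the first match wins, registration order is the tie-breaker whenever several signatures accept the same call.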
Yes and no... I tried this by changing all my tests in the `test_dispatch.py` file to use `overload`:
```python
from typing import Union

import pytest
from multimethod import DispatchError, overload


def test_keywords():
    class cls: pass

    @overload
    def func(arg):
        return 0

    @overload
    def func(arg: int):
        return 1

    @overload
    def func(arg: int, extra: Union[int, float]):
        return 2

    @overload
    def func(arg: int, extra: str):
        return 3

    @overload
    def func(arg: int, *, extra: cls):
        return 4

    assert func("sth") == 0
    assert func(0) == func(arg=0) == 1
    assert func(0, 0.0) == func(0, extra=0.0) == func(arg=0, extra=0.0) == 2
    assert func(0, 0) == func(0, extra=0) == func(arg=0, extra=0) == 2
    assert func(0, '') == func(0, extra='') == func(arg=0, extra='') == 3
    assert func(0, extra=cls()) == func(arg=0, extra=cls()) == 4
    with pytest.raises(DispatchError):
        func(0, cls())
```
Fails with:

```
self = typing.Union, args = (0.0,), kwds = {}

    def __call__(self, *args, **kwds):
>       raise TypeError(f"Cannot instantiate {self!r}")
E       TypeError: Cannot instantiate typing.Union
```
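The traceback comes from calling the annotation itself as a predicate: `Union[int, float](0.0)` tries to instantiate `typing.Union`, which is not allowed. A minimal sketch of one workaround converts a type hint into an `isinstance` check before it is used for dispatch. The helper name `to_predicate` is hypothetical; it relies on `typing.get_origin` and `typing.get_args` (Python 3.8+) and only handles unions of plain classes.

```python
import typing
from typing import Any, Callable, Union


def to_predicate(hint: Any) -> Callable[[Any], bool]:
    """Hypothetical helper: turn a type hint into a predicate usable for dispatch."""
    if typing.get_origin(hint) is Union:
        # Union[int, float] -> isinstance(value, (int, float))
        members = typing.get_args(hint)
        return lambda value: isinstance(value, members)
    if isinstance(hint, type):
        # A plain class such as int -> isinstance(value, int)
        return lambda value: isinstance(value, hint)
    # Anything else is assumed to already be a predicate and is returned unchanged.
    return hint
```

Note that `Optional[str]` is `Union[str, None]`, so it is covered by the first branch; unions containing subscripted generics such as `List[int]` would still need extra handling, since `isinstance` rejects them.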
It seems to largely work, except when varargs come into the picture. For example:
```python
from multimethod import overload


def test_var_positional():
    """Check that we can overload instance methods with variable positional arguments."""
    class cls:
        @overload
        def func(self):
            return 1

        @overload
        def func(self, *args: object):
            return 2

    assert cls().func() == 1
    assert cls().func(1) == 2
```
This always dispatches to the varargs version.
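One likely reason: `*args` may be empty, so `Signature.bind` succeeds for both overloads when `cls().func()` is called, and whichever signature the scan tries first wins. With a reverse scan (`reversed(self)`), that is the varargs one. A small standard-library check (the `binds` helper and the function names are hypothetical) shows that both signatures accept a bare `self`, while only the varargs one accepts an extra argument:

```python
import inspect
from typing import Callable


def binds(func: Callable, *args, **kwargs) -> bool:
    """Return True if the arguments fit func's signature (annotations are ignored)."""
    try:
        inspect.signature(func).bind(*args, **kwargs)
    except TypeError:
        return False
    return True


def no_extra(self):
    return 1


def with_varargs(self, *args: object):
    return 2
```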
Or when defaults come into play:
```python
from multimethod import overload


def test_different_signatures():
    @overload
    def func(a: int):
        return f'int: {a}'

    @overload
    def func(a: int, b: float = 3.0):
        return f'int_float: {a} / {b}'

    assert func(1) == 'int: 1'
```
Fails with:

```
Expected :'int: 1'
Actual   :'int_float: 1 / 3.0'
```
It largely does what I want, but not entirely. Take the last example: it seems to pick the last registered function, or the one with varargs, i.e. the one that can consume the most arguments. Logically, I would expect it to use the first one whose signature matches.
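The defaults case has the same root cause: `b: float = 3.0` lets the second overload bind `func(1)`, so both candidates fit and only the scan order decides. A minimal sketch comparing a forward scan with a reverse one (the `pick` helper and the function names are hypothetical):

```python
import inspect
from typing import Callable, List


def int_only(a: int):
    return f'int: {a}'


def int_float(a: int, b: float = 3.0):
    return f'int_float: {a} / {b}'


def pick(candidates: List[Callable], *args, last_first: bool = False) -> Callable:
    """Return the first candidate (in scan order) whose signature binds the arguments."""
    order = reversed(candidates) if last_first else candidates
    for func in order:
        try:
            inspect.signature(func).bind(*args)
        except TypeError:
            continue
        return func
    raise LookupError("no candidate matches")
```

A forward scan lets the earliest registration win ties (`int_only` for `(1,)`), whereas a reverse scan lets the most recent one win (`int_float`), which matches the failure above.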
OK, changing this:

```python
for sig in reversed(self):
```

to

```python
for sig in self:
```

already solved a few of the test cases.
```python
import inspect
from typing import Callable, Union, get_type_hints
from typing import overload as tp_overload

from multimethod import DispatchError, isa


class overload(dict):
    # ... remaining members of the class are omitted here ...

    def __init__(self, func: Callable):
        # Replace type hints with predicates: Union -> isa(members), classes -> isa(cls)
        for name, value in get_type_hints(func).items():
            if getattr(value, '__origin__', None) is Union:
                func.__annotations__[name] = isa(value.__args__)
            elif not callable(value) or isinstance(value, type):
                func.__annotations__[name] = isa(value)
        self[inspect.signature(func)] = func

    def _check(self, param, value, sub: bool = False):
        # For *args / **kwargs, check every collected element individually
        if not sub:
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                return all(self._check(param, v, sub=True) for v in value)
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                return all(self._check(param, v, sub=True) for v in value.values())
        return param.annotation is param.empty or param.annotation(value)

    def __call__(self, *args, **kwargs):
        """Dispatch to first matching function."""
        for sig in reversed(self):
            try:
                arguments = sig.bind(*args, **kwargs).arguments
            except TypeError:
                continue
            if all(
                self._check(param, arguments[name])
                for name, param in sig.parameters.items()
                if name in arguments
            ):
                return self[sig](*args, **kwargs)
        raise DispatchError("No matching functions found")

    @tp_overload
    def register(self, *args: type) -> Callable: ...

    @tp_overload
    def register(self, func: Callable) -> Callable: ...

    def register(self, *args) -> Callable:
        """Decorator for registering a function."""
        if len(args) == 1 and hasattr(args[0], '__annotations__'):
            func = args[0]
            self.__init__(func)
            return self if self.__name__ == func.__name__ else func  # type: ignore
        return lambda func: self.__setitem__(args, func) or func
```
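The `_check` method above applies a parameter's predicate to each element collected by `*args` or `**kwargs` rather than to the tuple or dict as a whole. A standalone sketch of that idea, assuming annotations are predicates; the helper names `check`, `matches` and `is_int` are hypothetical:

```python
import inspect
from typing import Any, Callable


def check(param: inspect.Parameter, value: Any) -> bool:
    """Apply param's predicate; for *args/**kwargs, apply it to every collected element."""
    pred = param.annotation
    if pred is inspect.Parameter.empty:
        return True
    if param.kind == inspect.Parameter.VAR_POSITIONAL:
        return all(pred(v) for v in value)  # value is the tuple bound to *args
    if param.kind == inspect.Parameter.VAR_KEYWORD:
        return all(pred(v) for v in value.values())  # value is the dict bound to **kwargs
    return pred(value)


def is_int(v: Any) -> bool:
    return isinstance(v, int)


def f(a: is_int, *rest: is_int, **opts: is_int):
    pass


def matches(func: Callable, *args, **kwargs) -> bool:
    sig = inspect.signature(func)
    try:
        bound = sig.bind(*args, **kwargs)
    except TypeError:
        return False
    return all(check(sig.parameters[name], value) for name, value in bound.arguments.items())
```

Here `matches(f, 1, 2, 3, x=4)` is true, while a single non-int in `rest` or `opts` makes the whole signature fail.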
This solves:
However, it does not solve this:
```python
@overload
def roshambo(left, right):
    return 'tie'


@roshambo.register(scissors, rock)
@roshambo.register(rock, scissors)
def roshambo(left, right):
    return 'rock smashes scissors'
```
Everything is a tie there.
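Two things likely interact here. The `register(*args)` branch above stores the raw `args` tuple, e.g. `(rock, scissors)`, as a dict key, while `__call__` expects every key to be an `inspect.Signature`. And with a forward scan, the unannotated `roshambo(left, right)` catch-all is registered first and binds any two arguments, so it is probably chosen before the typed variants are ever reached. A sketch of one way to keep the keys uniform: build a `Signature` from the function plus the explicit types with `Signature.replace` and `Parameter.replace`. The helper `signature_for_types` and the classes `Rock`/`Scissors` (standing in for `rock`/`scissors`) are hypothetical.

```python
import inspect
from typing import Callable, Tuple


def signature_for_types(func: Callable, types: Tuple[type, ...]) -> inspect.Signature:
    """Hypothetical helper: annotate func's leading parameters with isinstance
    predicates built from an explicit tuple of types, e.g. (rock, scissors)."""
    sig = inspect.signature(func)
    params = list(sig.parameters.values())
    for i, tp in enumerate(types):
        # `tp=tp` freezes the current type in each lambda (avoids late binding).
        params[i] = params[i].replace(annotation=lambda value, tp=tp: isinstance(value, tp))
    return sig.replace(parameters=params)


class Rock:
    pass


class Scissors:
    pass


def roshambo(left, right):
    return 'rock smashes scissors'


sig = signature_for_types(roshambo, (Rock, Scissors))
```

This only fixes the key type; the untyped fallback would still have to be tried last for the typed variants to win.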
Another thing that isn't supported by `overload`: pending types (string annotations such as `"A"`):
```python
from multimethod import overload


def test_unknown_types():
    class A: pass

    @overload
    def func(a: int):
        return 1

    @func.register
    def func(a: "A"):
        return 2

    assert func(1) == 1
    assert func(A()) == 2
```

Fails with:

```
E       NameError: name 'A' is not defined
```
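The `NameError` comes from resolving `"A"` eagerly with `get_type_hints` at registration time. A minimal sketch of deferring that resolution until the first call; the class `LazyHints` and the function `takes_a` are hypothetical, and only plain class annotations are handled:

```python
import inspect
import typing
from typing import Any, Callable, Dict, Optional


class LazyHints:
    """Hypothetical sketch: defer resolving string annotations ("A") until first use."""

    def __init__(self, func: Callable) -> None:
        self.func = func
        self._hints: Optional[Dict[str, Any]] = None

    def hints(self) -> Dict[str, Any]:
        if self._hints is None:
            # Still raises NameError if "A" is undefined at call time.
            self._hints = typing.get_type_hints(self.func)
        return self._hints

    def accepts(self, *args: Any, **kwargs: Any) -> bool:
        bound = inspect.signature(self.func).bind(*args, **kwargs)
        hints = self.hints()
        return all(
            name not in hints or isinstance(value, hints[name])
            for name, value in bound.arguments.items()
        )


def takes_a(a: "A"):
    return 2


# Registering works even though A does not exist yet ...
pending = LazyHints(takes_a)


class A:  # ... because A is only looked up when accepts() is first called.
    pass
```

This covers forward references to module-level names. In the test above, `A` is local to the test function, so `get_type_hints` would only find it if that enclosing namespace were passed explicitly through its `localns` argument.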
Hi @svaningelgem,

It seems you have added support for several important cases. Is there any particular reason why `Optional` arguments are out of scope?

```python
@overload
def f(x: int, c: str = None): ...
```
As of v1.10, dispatch supports optionals.
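For the `c: str = None` case asked about above, one option is a per-parameter predicate that treats a `None` default as an implicit `Optional`. Since Python 3.11, `typing.get_type_hints` no longer adds `Optional` for such defaults automatically, so this has to be done explicitly. A minimal sketch; the helper name `param_predicate` is hypothetical and only plain class annotations are handled:

```python
import inspect
from typing import Any, Callable


def param_predicate(param: inspect.Parameter) -> Callable[[Any], bool]:
    """Hypothetical helper: isinstance check for a parameter, treating a None
    default (e.g. `c: str = None`) as an implicit Optional."""
    ann = param.annotation
    if ann is inspect.Parameter.empty:
        return lambda value: True
    if param.default is None:
        return lambda value: value is None or isinstance(value, ann)
    return lambda value: isinstance(value, ann)


def f(x: int, c: str = None):
    return x, c
```

Writing `Optional[str]` explicitly avoids the special case altogether and is what current typing guidance recommends.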
Closing since this seems stalled. Maybe it can be split into separate ideas.
Hi @coady,

At work I ran into an issue where keyword arguments were not matched properly, so I dove into the matching engine to handle the full signature. Since `multimethod` only matches on positional arguments, you suggested I use `multidispatch`, but that also fell short of what I needed. So I looked at other libraries, but there doesn't seem to be much on the market for this kind of routing. I found one called `overload` that did nearly what I needed, so I took its tests and fixed the code until they ran through the new `multidispatch` code.

By the way, the only class I touched is `multidispatch` (and I moved some `multimethod` tests to the `test_method` file, as they belong there IMHO).

I'm raising this PR to get your opinion on what I did and some pointers on how to improve it.
One thing I'm not very certain about is what I also mentioned in the README:

The last bit happens because "A" is matched with `b`, which should be a float. This is because I'm using the `Signature.bind` method, which doesn't take the type hints into account. Implementing a full-blown matching engine might lead too far and slow things down considerably, so I don't know if it's truly worth the effort right now.
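A quick standard-library illustration of that point: `Signature.bind` only checks arity and parameter names, so `"A"` binds to `b` despite the `float` hint. The `hints_ok` post-check is a hypothetical, deliberately cheap way to reject such bindings for plain class annotations, leaving unannotated parameters and subscripted hints such as `Union[...]` unchecked:

```python
import inspect


def func(a: int, b: float = 3.0):
    return f'int_float: {a} / {b}'


# bind() succeeds: b == "A" even though it is annotated as float.
bound = inspect.signature(func).bind(1, "A")


def hints_ok(bound: inspect.BoundArguments) -> bool:
    """Return False if a bound value violates a plain-class annotation."""
    params = bound.signature.parameters
    for name, value in bound.arguments.items():
        ann = params[name].annotation
        # Skip unannotated parameters and non-class hints such as Union[...].
        if ann is inspect.Parameter.empty or not isinstance(ann, type):
            continue
        if not isinstance(value, ann):
            return False
    return True
```

The extra cost is one `isinstance` call per bound argument, which is far cheaper than a full matching engine while still catching the `"A"`-to-`float` case.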
At least the basics work: match ANY signature, with ANY kind of arguments.
Let's discuss :)