Closed PhilipVinc closed 1 year ago
Hey @PhilipVinc,
Sorry for not getting back to this earlier. I'm a bit busy at the moment.
I fully agree that being able to dispatch on the dimensionality of an array would be an amazing feature to have. Your proposed solution seems like something that could be made to work.
Here's one way of doing this. First, we make `type_of` extendable by using dispatch.
@_dispatch
def type_of(obj):
    """Get the Plum type of an object.
    Args:
        obj (object): Object to get type of.
    Returns:
        ptype: Plum type of `obj`.
    """
    if isinstance(obj, list):
        return List(_types_of_iterable(obj))
    if isinstance(obj, tuple):
        return Tuple(*(type_of(x) for x in obj))
    return ptype(type(obj))
Then this is possible:
import numpy as np
from plum import dispatch, parametric, type_of

@parametric
class NPArray(np.ndarray):
    pass

@type_of.dispatch
def type_of(x: np.ndarray):
    return NPArray[x.ndim]

@dispatch
def f(x: NPArray[1]):
    print("Vector!")

@dispatch
def f(x: NPArray[2]):
    print("Matrix!")

# Plum currently avoids unnecessary `type_of` calls. We would need to enable the
# use of `type_of` by default. Here's a hack that works for now:
f._parametric = True
This would be the outcome:
>>> f(np.random.randn(5))
Vector!
>>> f(np.random.randn(5, 5))
Matrix!
>>> f(np.random.randn(5, 5, 5))
NotFoundLookupError: For function "f", signature Signature(plum.parametric.NPArray[3]) could not be resolved.
These extensions might be best placed in LAB. For example, you would then `import lab.jax as B` and write `def f(x: B.JAXArray[2])` or `def f(x: B.Array[2])`.
How does this sound?
I've made the above change in the most recent release, 1.1.1, and added an entry to the README. I'll leave this issue open and close it once good implementations of backend-specific array types (like `JAXArray[1]`) are available.
Somewhat related to the above: do you consider it feasible/possible/desirable to be able to dispatch on the basis of (a computable subset of) the value of arguments? I'm trying to do something akin (though not literally identical) to dispatch based on array shape. In practice, a finite number of concrete shapes would be in play, but I can't enumerate them upfront, since it depends on the use of the library. That is, I'd like to match my method dispatch on a computable function of such a hashable type (such as a shape).
@dispatch(property=lambda x: x.shape, condition=lambda s: s < (3, 4, 5))
def f(x: np.ndarray):
    print("This array has shape smaller than (3, 4, 5)")
Computing the condition might be expensive, but it should only happen once for every shape that occurs; after that it should just be a dict lookup based on a hashable type, just like any other dynamically dispatched call.
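A minimal, stdlib-only sketch of the caching scheme described above. `ConditionDispatcher` is a hypothetical class, not a Plum API: the attribute function is bound once in the constructor, each method is registered with a condition on that attribute, the conditions run only on the first call for a given attribute value, and the chosen method is then served from a plain dict. Here the first matching condition in registration order wins, which is only one possible policy; `len` stands in for a shape-like attribute.

```python
from typing import Any, Callable, Dict, Hashable, List, Tuple


class ConditionDispatcher:
    """Hypothetical dispatcher that selects a method via conditions on a hashable attribute."""

    def __init__(self, attribute: Callable[[Any], Hashable]) -> None:
        self.attribute = attribute
        self.methods: List[Tuple[Callable[[Hashable], bool], Callable]] = []
        self.cache: Dict[Hashable, Callable] = {}

    def register(self, condition: Callable[[Hashable], bool]) -> Callable:
        def decorator(method: Callable) -> Callable:
            self.methods.append((condition, method))
            self.cache.clear()  # A new method may change earlier resolutions.
            return method

        return decorator

    def resolve(self, key: Hashable) -> Callable:
        # The (possibly expensive) conditions only run on a cache miss, i.e. once per key.
        if key not in self.cache:
            for condition, method in self.methods:
                if condition(key):
                    self.cache[key] = method
                    break
            else:
                raise LookupError(f"No method matches attribute {key!r}.")
        return self.cache[key]

    def __call__(self, x: Any) -> Any:
        return self.resolve(self.attribute(x))(x)


f = ConditionDispatcher(attribute=len)


@f.register(lambda n: n <= 3)
def _short(x):
    return "short"


@f.register(lambda n: n <= 10)
def _medium(x):
    return "medium"


print(f([1, 2]))          # short
print(f(list(range(5))))  # medium
```

Any further call with a length of 2 or 5 skips the conditions entirely and costs one dict lookup.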
You can already do it. You need to define a parametric type with a custom type_of and then use it to define a dispatch rule.
See this test, where I do it for the number of dimensions (but you can generalise it to work on shapes) https://github.com/wesselb/plum/blob/fb29f9723f4b006457c74beb4cb09c811c1c983a/tests/test_parametric.py#L680
import numpy as np
import pytest
from plum import NotFoundLookupError, dispatch, parametric, type_of

@parametric(runtime_type_of=True)
class NPArray(np.ndarray):
    pass

@type_of.dispatch
def type_of_extension(x: np.ndarray):
    return NPArray[x.ndim]

@dispatch
def f(x: NPArray[1]):
    return "vector"

@dispatch
def f(x: NPArray[2]):
    return "matrix"

assert f(np.random.randn(10)) == "vector"
assert f(np.random.randn(10, 10)) == "matrix"
with pytest.raises(NotFoundLookupError):
    f(np.random.randn(10, 10, 10))
Ah sorry, now I see what you are asking is a bit different... I think that implementing this is going to be a bit tough, though...
Yeah... I'm not seeing any other libraries that dispatch in any way on object value, just purely on type (this example of dispatching based on array dimension is the first example I'm seeing to the contrary).
So either this is a great idea, or one that's going to blow up in your face once you think about it more / work with it more?
For the use case I have in mind, though, the hashable attribute is de facto a dynamic extension of the typing system; and while my use case isn't precisely identical to dispatching based on array shape, I can in fact think of plenty of legitimate uses for that as well (say, specializing a 2D/3D cross product or small-matrix determinants).
Certainly in many array-based languages, the shape of the array is effectively regarded as part of the type description. But some syntax to more flexibly match ranges of values (or general computable functions thereof), rather than just single types or collections thereof, would be a requirement for me, and as far as I can tell that isn't currently covered by plum.
Hey @PhilipVinc and @EelcoHoogendoorn,
As @PhilipVinc says, I think this might be challenging, but it could conceivably be done.
I've hacked something very quick together. Would something like the below be what you're after?
import numpy as np
from plum import dispatch, parametric, type_of
from plum.parametric import CovariantMeta

class DynamicTypeParameter:
    """A type parameter which contains a function `check` which can check whether
    another type parameter is a type subparameter."""

    def __init__(self, parameter, check):
        self.parameter = parameter
        self.check = check

    def __str__(self):
        return str(self.parameter)

    def __repr__(self):
        return repr(self.parameter)

_original_is_sub_type_parameter = CovariantMeta._is_sub_type_parameter

def _new_is_sub_type_parameter(cls, par_cls, subclass, par_subclass):
    # Resolve dynamic type parameter for the parent class.
    if len(par_cls) == 1 and isinstance(par_cls[0], DynamicTypeParameter):
        par_check = par_cls[0].check
        par_parameter = (par_cls[0].parameter,)
    else:
        par_check = None
    # Resolve dynamic type parameter for the subclass.
    if len(par_subclass) == 1 and isinstance(par_subclass[0], DynamicTypeParameter):
        par_subclass = (par_subclass[0].parameter,)
    if par_check:
        # Use the dynamic check.
        return par_check(*par_subclass)
    else:
        return _original_is_sub_type_parameter(cls, par_cls, subclass, par_subclass)

CovariantMeta._is_sub_type_parameter = _new_is_sub_type_parameter

@parametric(runtime_type_of=True)
class _NPArray(np.ndarray):
    """A type for NumPy arrays where the type parameter specifies the shape."""

class _NPArrayClass:
    """Just some sugar to make indexing with square brackets work."""

    def __getitem__(self, shape):
        def check_subshape(other_shape):
            try:
                other_shape = tuple(other_shape)
            except TypeError:
                return False
            return len(other_shape) == len(shape) and all(
                s1 <= s2 for s1, s2 in zip(other_shape, shape)
            )

        return _NPArray[DynamicTypeParameter(shape, check_subshape)]

NPArray = _NPArrayClass()

@type_of.dispatch
def type_of(x: np.ndarray):
    # Hook into Plum's type inference system to produce an appropriate instance of
    # `NPArray` for NumPy arrays.
    return NPArray[x.shape]

@dispatch
def f(x: NPArray[(10,)]):
    return "vector of (10,)"

@dispatch
def f(x: NPArray[(10, 20)]):
    return "matrix of (10, 20)"

print(f(np.random.randn(8)))
# vector of (10,)
print(f(np.random.randn(11)))
# NotFoundLookupError: For function "f", signature Signature(__main__._NPArray[(11,)]) could not be resolved.
print(f(np.random.randn(8, 18)))
# matrix of (10, 20)
print(f(np.random.randn(8, 22)))
# NotFoundLookupError: For function "f", signature Signature(__main__._NPArray[(8, 22)]) could not be resolved.
Hey @wesselb, that's looking pretty neat. The ability to specify some kind of computable matching function is quite critical for my actual application, though. I suppose that would be a significant departure from the current syntax, which is parsed from the type annotations of the signature itself.
Still, one could imagine keeping the parsed signature as the default, with a more flexible syntax as illustrated below (where `attribute` maps the input to a hashable attribute of the input, and `condition` maps that attribute to a bool):
@dispatch(attribute=lambda x: type(x), condition=lambda a: a == int)
def f(x):
    print("This is an int")

# The above could be seen as a verbose form of the below:
@dispatch
def f(x: int):
    print("This is an int")
Note: in the mockup code I cobbled together, the attribute selection function is fed into the dispatcher via its constructor rather than the decorator call, since it tends to be the same across a whole family of functions, so there's no need to repeat it every time.
I was wondering if it is considered bad style to dispatch based on the object value, but according to Wikipedia (https://en.wikipedia.org/wiki/Multiple_dispatch):
Multiple dispatch or multimethods is a feature of some programming languages in which a function or method can be dynamically dispatched based on the run-time (dynamic) type or, in the more general case, some other attribute of more than one of its [arguments](https://en.wikipedia.org/wiki/Parameter_(computer_programming)).
So it's not unheard of, even if I haven't seen it in the Python ecosystem before.
The ability to specify some kind of computable matching function is quite critical for my actual application though.
Would something like how the example works suffice, or would you need something more flexible? To clarify, in the example, the second argument to `DynamicTypeParameter` is a function which checks if a given type parameter is a type subparameter. This is the relevant bit of the example:
...
        def check_subshape(other_shape):
            try:
                other_shape = tuple(other_shape)
            except TypeError:
                return False
            return len(other_shape) == len(shape) and all(
                s1 <= s2 for s1, s2 in zip(other_shape, shape)
            )

        return _NPArray[DynamicTypeParameter(shape, check_subshape)]
...
It is true that all information would have to be contained in the type parameter of the object, which is more limited than the general case which you outline where the function could take in the object itself.
I was wondering if this is considered bad style, to dispatch based on the object value; but according to wikipedia: (...)
This is really interesting. I think one could conceivably hack this together, but it might be challenging to make it really work well. For example, at the moment, types are hashable, which enables caching, and caching is crucial to obtain reasonable performance. If you use an arbitrary function to check whether an object belongs to the type family, then caching in this way wouldn't be possible anymore, because the object might not be hashable. If the object were hashable, then you could make it part of the type parameter and use the approach from the example; perhaps such a halfway house might be a reasonable solution?
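To make the hashability argument concrete, here is a small sketch using only `functools.lru_cache`. `resolve` is a hypothetical stand-in for an expensive method-resolution step, and `shape_of` a hypothetical helper that derives a shape-like tuple from a nested list. Caching on the derived tuple works and runs the resolution once per distinct shape, whereas caching on the raw value fails because lists (like NumPy arrays) are unhashable. In miniature, the hashable descriptor plays the role of the type parameter.

```python
from functools import lru_cache

calls = []  # Records each time the resolution step actually runs.


@lru_cache(maxsize=None)
def resolve(type_parameter):
    """Stand-in for an expensive method resolution keyed on a hashable parameter."""
    calls.append(type_parameter)
    return f"{len(type_parameter)}-dimensional"


def shape_of(nested):
    """Shape-like tuple of a rectangular nested list (hashable, unlike the list)."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)


data = [[1, 2, 3], [4, 5, 6]]
print(resolve(shape_of(data)))  # 2-dimensional
print(resolve(shape_of(data)))  # 2-dimensional, served from the cache
print(calls)                    # [(2, 3)]

try:
    resolve(data)  # Caching on the raw value fails: lists are unhashable.
except TypeError as error:
    print("Cannot cache on the value itself:", error)
```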
Would something like how the example works suffice, or would you need something more flexible?
Yeah, I'd need something more flexible.
types are hashable
Yeah; in my example I also meant for the 'attribute' lambda to return some hashable, 'type-like' information attribute of the object. A shape or dimension of an array would qualify, as being immutable type-like descriptors of the underlying raw array. The dtype of the array itself would qualify as well, obviously, even if not part of the 'type' on the level of the python type system.
But there are other use-cases for imbuing array elements with various type-like annotations beyond just their dtype; perhaps physical units, or in the case I have in mind, blades/grades of a geometric-algebra. Those hashable blade-descriptor objects would be quite involved objects themselves, with a variety of computable properties, on the basis of which we would want to specialize our function dispatch. That is, it would invite a richer syntax than just testing if the hashable attribute is inside some set, as per the typical type-based dispatch.
Hmm, perhaps I'm misunderstanding, but I think what you're after is possible in the setup of the example. E.g., what about something like the following?
from plum import dispatch, parametric, type_of
from plum.parametric import CovariantMeta

class DynamicTypeParameter:
    """A type parameter which contains a function `check` which can check whether
    another type parameter is a type subparameter."""

    def __init__(self, check, parameter=None):
        self.check = check
        self.parameter = parameter

    def __str__(self):
        return str(self.parameter)

    def __repr__(self):
        return repr(self.parameter)

_original_is_sub_type_parameter = CovariantMeta._is_sub_type_parameter

def _new_is_sub_type_parameter(cls, par_cls, subclass, par_subclass):
    # Resolve dynamic type parameter for the parent class.
    if len(par_cls) == 1 and isinstance(par_cls[0], DynamicTypeParameter):
        par_check = par_cls[0].check
        par_parameter = (par_cls[0].parameter,)
    else:
        par_check = None
    # Resolve dynamic type parameter for the subclass.
    if len(par_subclass) == 1 and isinstance(par_subclass[0], DynamicTypeParameter):
        par_subclass = (par_subclass[0].parameter,)
    if par_check:
        # Use the dynamic check.
        return par_check(*par_subclass)
    else:
        return _original_is_sub_type_parameter(cls, par_cls, subclass, par_subclass)

CovariantMeta._is_sub_type_parameter = _new_is_sub_type_parameter

@parametric
class DynamicallyTypedObject:
    def __init__(self, obj):
        self.obj = obj

    @classmethod
    def __infer_type_parameter__(cls, obj):
        # Use the object itself as the type parameter.
        return obj

@dispatch
def f(x: DynamicallyTypedObject[
    DynamicTypeParameter(check=lambda p: hasattr(type(p), "__len__") and len(p) <= 10)
]):
    print("Method for `len(x) <= 10`!")

import torch  # Use PyTorch, because NumPy arrays cannot be hashed.

f(DynamicallyTypedObject(torch.ones(5)))
# Method for `len(x) <= 10`!
f(DynamicallyTypedObject(torch.ones(15)))
# NotFoundLookupError
We can add some sugar on top of this to make it behave like your proposal of `@dispatch`:
import numpy as np
from functools import wraps
from plum import dispatch, parametric, type_of
from plum.parametric import CovariantMeta

class DynamicTypeParameter:
    """A type parameter which contains a function `check` which can check whether
    another type parameter is a type subparameter."""

    def __init__(self, check, parameter=None):
        self.check = check
        self.parameter = parameter

    def __str__(self):
        return str(self.parameter)

    def __repr__(self):
        return repr(self.parameter)

_original_is_sub_type_parameter = CovariantMeta._is_sub_type_parameter

def _new_is_sub_type_parameter(cls, par_cls, subclass, par_subclass):
    # Resolve dynamic type parameter for the parent class.
    if len(par_cls) == 1 and isinstance(par_cls[0], DynamicTypeParameter):
        par_check = par_cls[0].check
        par_parameter = (par_cls[0].parameter,)
    else:
        par_check = None
    # Resolve dynamic type parameter for the subclass.
    if len(par_subclass) == 1 and isinstance(par_subclass[0], DynamicTypeParameter):
        par_subclass = (par_subclass[0].parameter,)
    if par_check:
        # Use the dynamic check.
        return par_check(*par_subclass)
    else:
        return _original_is_sub_type_parameter(cls, par_cls, subclass, par_subclass)

CovariantMeta._is_sub_type_parameter = _new_is_sub_type_parameter

def attribute_dispatch(attribute, check):
    def decorator(f):
        @parametric
        class _DynamicallyTypedObject:
            def __init__(self, obj):
                self.obj = obj

            @classmethod
            def __infer_type_parameter__(cls, obj):
                return attribute(obj)

        def wrapped_f(x: _DynamicallyTypedObject[DynamicTypeParameter(check)]):
            # Unwrap, so that `f` receives the original argument.
            return f(x.obj)

        wrapped_f.__name__ = f.__name__
        wrapped_f = dispatch(wrapped_f)

        @wraps(f)
        def second_wrapped_f(x):
            return wrapped_f(_DynamicallyTypedObject(x))

        return second_wrapped_f

    return decorator

def _safe_le(x, y):
    try:
        return x <= y
    except Exception:
        return False

@attribute_dispatch(attribute=lambda x: len(x), check=lambda x: _safe_le(x, 10))
def f(x):
    print("Method for `len(x) <= 10`!")

f(np.ones(5))
# Method for `len(x) <= 10`!
f(np.ones(15))
# NotFoundLookupError
I don't fully appreciate the internal mechanisms (I suppose it requires some extra complexity to shoehorn it into the existing mechanisms of plum), but yeah, the external API indeed looks the part!
With the minor detail that I think it'd usually be nice to separately bind the attribute to create a len/whatever_dispatcher, since that part tends to be constant over a bunch of annotations, I suppose.
With the minor detail that I think it'd usually be nice to separately bind the attribute to create a len/whatever_dispatcher, since that part tends to be constant over a bunch of annotations, I suppose.
That should certainly be possible!
I think I really like this idea and that it would be a valuable addition. However, you're indeed right that some precarious manoeuvring is required to fit it into the current internal mechanisms. There are still a few things unclear to me.
Firstly, to make this really work, the current mechanisms would require the ability to ask whether one `check` function is more specific than another. E.g., `lambda x: len(x) <= 5` is more specific than `lambda x: len(x) <= 10`. One way to make this work is to not use `lambda`s, but to use objects which can be compared. For example, make a `Shape` object which is hashable such that any two shapes can be compared. This comes at the cost of boilerplate, but perhaps it's not too bad.
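A minimal sketch of such a comparable object, assuming that "smaller shape" means componentwise comparison between shapes of the same rank (the thread does not pin this down). `Shape` and `more_specific` here are hypothetical plain-Python code, not Plum's `Comparable`. Since the componentwise `<=` is only a partial order, some pairs of shapes are incomparable, which is exactly the case in which a dispatcher would have to report an ambiguity.

```python
class Shape:
    """Hypothetical hashable shape; `<=` is the componentwise (partial) order."""

    def __init__(self, *dims: int) -> None:
        self.dims = tuple(dims)

    def __hash__(self) -> int:
        return hash(self.dims)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Shape) and self.dims == other.dims

    def __le__(self, other: "Shape") -> bool:
        # Same rank, and no dimension larger than the corresponding one in `other`.
        return len(self.dims) == len(other.dims) and all(
            a <= b for a, b in zip(self.dims, other.dims)
        )

    def __repr__(self) -> str:
        return f"Shape{self.dims}"


def more_specific(a: Shape, b: Shape) -> bool:
    """`a` is a strictly narrower bound than `b`: every shape <= a is also <= b."""
    return a <= b and a != b


print(more_specific(Shape(5, 5), Shape(10, 10)))                     # True
print(Shape(5, 20) <= Shape(10, 10), Shape(10, 10) <= Shape(5, 20))  # False False
```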
Secondly, in the prototype, the information about the shape is preserved by wrapping the object in another object with a type parameter. I think I still like the fundamental principle that all information necessary for dispatch should be retained when you take `type(object)`.
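As a rough illustration of retaining the dispatch information in `type(object)`, the sketch below creates, once per distinct parameter, a subclass of a base type that stores the parameter as a class attribute, so that an instance's type alone carries it. `parametrised` and `with_shape` are hypothetical helpers written in plain Python; Plum's `@parametric` plays a comparable role, but with its own, richer machinery.

```python
from functools import lru_cache
from typing import Hashable


@lru_cache(maxsize=None)
def parametrised(base: type, parameter: Hashable) -> type:
    """Hypothetical helper: the subclass of `base` tagged with `parameter`, created once."""
    return type(f"{base.__name__}[{parameter!r}]", (base,), {"type_parameter": parameter})


def with_shape(values: list) -> list:
    # Re-wrap a flat list so that its shape-like tuple becomes part of its type.
    return parametrised(list, (len(values),))(values)


x = with_shape([1.0, 2.0, 3.0])
print(type(x).__name__)        # list[(3,)]
print(type(x).type_parameter)  # (3,)
print(isinstance(x, list))     # True
```

Because the subclasses are cached, all lists of the same length share one type object, so anything keyed on the type (such as a dispatch cache) stays hashable and cheap.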
What I would propose is an interface like the below. (Note that this is pseudo-code and doesn't actually run.)
import numpy as np
from plum import dispatch, parametric
from plum.util import Comparable

class Shape(Comparable):
    def __init__(self, *dims):
        self.dims = dims

    @dispatch
    def __le__(self, other: "Shape"):
        return ...

@parametric
class ShapedArray(np.ndarray):
    @classmethod
    def __infer_type_parameter__(cls, x):
        return Shape(*x.shape)

@dispatch
def f(x: ShapedArray[Shape(10, 10)]):
    # Do something for objects with `shape <= Shape(10, 10)`.
    ...

@dispatch
def f(x: np.ndarray):
    # Ensure that the shape information is always included.
    return f(ShapedArray(x))

f(np.random.randn(15, 15))
What would you think about this proposal? Would that suffice for your use case?
I suppose it would suffice; but considering the number of different kinds of matching functions that I need, and the fact that they don't really have a ton of reuse (many are one-off), the need to wrap every condition into a type isn't very attractive.
While I like using nice libraries over reinventing the wheel, getting exactly the syntax I want takes me 40 lines of code added to my project.
That's totally fair enough.
getting exactly the syntax I want takes me 40 lines of code added to my project.
How would you deal with the below ambiguity?
@dispatch(condition=lambda x: len(x) <= 10)
def f(x):
    print("This is an object of length at most 10")

@dispatch(condition=lambda x: len(x) <= 20)
def f(x):
    print("This is an object of length at most 20")
If you call `f` with an `x` such that `len(x) = 5`, both conditions would evaluate to `True`. The key idea of multiple dispatch is that the most specific method should be chosen. However, unless you inspect these `lambda`s, there is no way to know which of the conditions is most specific. EDIT: The wrapping into types, which admittedly isn't terribly attractive syntactically, forces you to choose an order which is then used to determine which is most specific.
What I currently do is let the order in which the annotations are registered determine the order in which they are matched, with an optional kwarg to override the insertion order, which might be useful when extending base library functionality.
Instead of maintaining an ordered list and letting that be decisive, it might be better to define an explicit priority number. That way it'd be easier to manage the situation where there are multiple extension modules and you don't want to risk an import-order-dependent result. You could just insert with fractional priority, and maybe raise an error in case of ambiguity?
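A minimal sketch of the explicit-priority idea, assuming a hypothetical `PriorityDispatcher` (not a Plum API). Among all methods whose condition holds, the highest priority wins; fractional priorities let an extension slot in between existing methods without depending on import order; and two matching methods with equal priority raise an error instead of being resolved silently.

```python
from typing import Any, Callable, List, Tuple


class AmbiguousLookupError(LookupError):
    """Raised when several matching methods share the highest priority."""


class PriorityDispatcher:
    """Hypothetical dispatcher: the matching method with the highest priority wins."""

    def __init__(self) -> None:
        self.methods: List[Tuple[float, Callable[[Any], bool], Callable]] = []

    def register(self, condition: Callable[[Any], bool], priority: float = 0.0) -> Callable:
        def decorator(method: Callable) -> Callable:
            self.methods.append((priority, condition, method))
            return method

        return decorator

    def __call__(self, x: Any) -> Any:
        matches = [(p, m) for p, c, m in self.methods if c(x)]
        if not matches:
            raise LookupError(f"No method matches {x!r}.")
        best = max(p for p, _ in matches)
        winners = [m for p, m in matches if p == best]
        if len(winners) > 1:
            raise AmbiguousLookupError(f"{len(winners)} methods match with priority {best}.")
        return winners[0](x)


f = PriorityDispatcher()


@f.register(lambda x: len(x) <= 20, priority=1)
def _at_most_20(x):
    return "length at most 20"


@f.register(lambda x: len(x) <= 10, priority=2)
def _at_most_10(x):
    return "length at most 10"


# E.g. an extension module slotting a method in between the two above.
@f.register(lambda x: len(x) <= 15, priority=1.5)
def _at_most_15(x):
    return "length at most 15"


print(f("abc"))     # length at most 10
print(f("a" * 12))  # length at most 15
print(f("a" * 18))  # length at most 20
```

Note that the priorities are chosen by hand so that narrower conditions rank higher; nothing in this sketch checks that they agree with the conditions themselves.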
I see! If that approach or assigning a priority number suffices, then perhaps it is not necessary to go through a lot of technical hoops for a fully general implementation of the idea. My sense is that determining which `condition` function is more specific is the key difficulty (and also what lies at the heart of multiple dispatch), and a general implementation will necessarily involve at least a bit of extra boilerplate which specifies how any two `condition` functions can be compared.
Yeah... I don't know that there is a general answer to the question of how to implement such comparisons. The other option would be for the matching function to return a number rather than a bool. That might work in some situations where you want this to be some complex estimate of the input (like maybe an estimate of runtime for a given implementation), but otherwise it's just messy.
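For comparison, a sketch of the "return a number instead of a bool" alternative, assuming a hypothetical `ScoreDispatcher`: each scorer returns `None` when its method does not apply and otherwise a score, and the highest score wins. The toy scores below are negated cost estimates; the example also hints at why this gets messy, since independently written scorers must agree on a common scale for the comparison to mean anything.

```python
from typing import Any, Callable, List, Optional, Tuple

# A scorer returns None if its method does not apply, otherwise a score.
Scorer = Callable[[Any], Optional[float]]


class ScoreDispatcher:
    """Hypothetical dispatcher: the applicable method with the highest score wins."""

    def __init__(self) -> None:
        self.methods: List[Tuple[Scorer, Callable]] = []

    def register(self, scorer: Scorer) -> Callable:
        def decorator(method: Callable) -> Callable:
            self.methods.append((scorer, method))
            return method

        return decorator

    def __call__(self, x: Any) -> Any:
        scored = [(scorer(x), method) for scorer, method in self.methods]
        applicable = [(s, m) for s, m in scored if s is not None]
        if not applicable:
            raise LookupError(f"No method applies to {x!r}.")
        # `max` returns the first maximal element, so ties go to the earliest registration.
        _, method = max(applicable, key=lambda pair: pair[0])
        return method(x)


solve = ScoreDispatcher()


@solve.register(lambda n: -(float(n) ** 3))  # Toy cost model: ~n**3 for a dense method.
def _dense(n):
    return "dense"


@solve.register(lambda n: -(1000.0 + 10.0 * n) if n >= 100 else None)  # Large n only.
def _iterative(n):
    return "iterative"


print(solve(10))   # dense
print(solve(200))  # iterative
```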
On a related note: if one is to use generic object attributes to decide multiple dispatch, it's nice to encourage people to implement a flyweight pattern for those objects, to guarantee trivial equality comparisons between them:
class FlyweightMixin:
    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return self is other

class FlyweightFactory:
    def __init__(self):
        self.flyweight_pool = {}

    def construct(self, key, value) -> FlyweightMixin:
        """All constructor calls of flyweights are supposed to be made via this factory method.

        `value` is a zero-argument callable which builds the flyweight; it is only called
        the first time `key` is seen."""
        if key in self.flyweight_pool:
            return self.flyweight_pool[key]
        else:
            value = value()
            self.flyweight_pool[key] = value
            return value
I guess it's not a hard requirement that the attributes you dispatch on have cheap value-equality comparisons, but if you care about performance it sure is nice. In my case these attributes contain potentially humongous NumPy arrays that do not natively compare cheaply, and I do care about performance. If you end up implementing something like this, it might be useful to include something like this in the example. I'd already hacked in similar functionality, but it's nicer to factor out such a known pattern with an existing name, and it took me a while to realize it could be quite simple indeed.
This is the same trick Python uses under the hood (interning) to make comparisons of some strings and other primitive values trivial. Note that this is a pretty simple/rough example; one nice feature to add would be for the flyweight mixin to sabotage the normal constructor, so your code can't accidentally create one outside the pool, because that'd be a bug that would break the value-equality implementation.
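A minimal sketch of that refinement, in which the class owns the pool so that ordinary construction cannot bypass it: `__new__` returns the pooled instance for a given key, which keeps the identity-based `__eq__`/`__hash__` consistent. `Flyweight`, `Blade`, and the `_init_once` hook are hypothetical names; the pool grows without bound and is not thread-safe.

```python
from typing import Any, Dict, Hashable, Tuple


class Flyweight:
    """Instances are interned per (class, key), so equality can be identity."""

    _pool: Dict[Tuple[type, Hashable], "Flyweight"] = {}

    def __new__(cls, key: Hashable):
        pool_key = (cls, key)
        instance = Flyweight._pool.get(pool_key)
        if instance is None:
            # First request for this key: build the instance and store it in the pool.
            instance = super().__new__(cls)
            instance.key = key
            instance._init_once(key)
            Flyweight._pool[pool_key] = instance
        return instance

    def _init_once(self, key: Hashable) -> None:
        """Hook for expensive one-time setup; runs only when an instance is created."""

    def __hash__(self) -> int:
        return id(self)

    def __eq__(self, other: Any) -> bool:
        return self is other


class Blade(Flyweight):
    def _init_once(self, key: Hashable) -> None:
        # Hypothetical expensive setup derived from the key.
        self.grade = len(key)


a = Blade((1, 2))
print(a is Blade((1, 2)))  # True
print(a.grade)             # 2
```

Setup goes in `_init_once` rather than `__init__`, because `__init__` would run again on every constructor call, even when the pooled instance is returned.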
Yeah... I don't know that there is a general answer to the question of how to implement such comparisons. The other option would be for the matching function to return a number rather than a bool. That might work in some situations where you want this to be some complex estimate of the input (like maybe an estimate of runtime for a given implementation), but otherwise it's just messy.
Hmm, yeah. It's a hard problem!
On a related note: if one is to use generic object attributes to decide multiple dispatch, it's nice to encourage people to implement a flyweight pattern for those objects, to guarantee trivial equality comparisons between them:
This is a nice suggestion! Certainly a cheap performance boost in certain scenarios. If we were to go down the path of implementing dispatch based on object attributes, it would be worth thinking about how to make this as efficient as possible and also how to reduce boilerplate as much as possible.
@wesselb Am I right in thinking that support for this would be automatically included in the case that beartype dispatch is supported (re: #53)? See e.g. relevant beartype features. I'm pretty sure any of these types of situations could be checked with the new `door` API and the beartype validators.
@tbsexton, combined with your proposal to perform dispatch solely using `isinstance` checks, I think you might be right! This is exciting!
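To illustrate the idea without relying on beartype's actual interface, the sketch below attaches predicates to a type with `typing.Annotated` (Python 3.9+) and checks them in a hypothetical `satisfies` helper: an `isinstance` check on the base type followed by the predicates. `dispatch_first` is likewise hypothetical and simply takes the first matching hint in list order, so it sidesteps the specificity question discussed earlier in the thread. beartype's validators (`beartype.vale.Is`) offer a full-featured version of such predicate-carrying hints; see its documentation for the real API.

```python
from typing import Annotated, Any, Callable, List, Tuple, get_args, get_origin


def satisfies(obj: Any, hint: Any) -> bool:
    """Hypothetical isinstance-like check that also runs predicates stored in `Annotated`."""
    if get_origin(hint) is Annotated:
        base, *predicates = get_args(hint)
        return isinstance(obj, base) and all(predicate(obj) for predicate in predicates)
    return isinstance(obj, hint)


def dispatch_first(obj: Any, methods: List[Tuple[Any, Callable]]) -> Any:
    """Hypothetical: call the first method whose hint `obj` satisfies, in list order."""
    for hint, method in methods:
        if satisfies(obj, hint):
            return method(obj)
    raise LookupError(f"No method for {obj!r}.")


Pair = Annotated[tuple, lambda x: len(x) == 2]
ShortList = Annotated[list, lambda x: len(x) <= 10]

methods = [
    (Pair, lambda x: "pair"),
    (ShortList, lambda x: "short list"),
    (list, lambda x: "long list"),
]

print(dispatch_first((1, 2), methods))           # pair
print(dispatch_first([1, 2, 3], methods))        # short list
print(dispatch_first(list(range(20)), methods))  # long list
```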
It would be nice (even though admittedly hacky) if we could dispatch on the number of dimensions of a JAX/NumPy array, which can be probed with `.ndim` on an instance. This is admittedly not part of the type information in Python, but maybe we can hack it in?
I was thinking of creating a custom parametric type for signatures, but then the problem is resolving the call signature and bringing this information from the value domain to the type domain.
This happens in `parametric.py`, if I am not mistaken, and would require changing this function to allow hacking in some custom types... Do you have any idea on what would be the best way to implement this?