wesselb / lab

A generic interface for linear algebra backends
MIT License

Function registration with lab + plum 2.0 #6

Open InfProbSciX opened 1 year ago

InfProbSciX commented 1 year ago

With previous versions of plum and lab (backends==1.4.32, plum-dispatch==1.7.4), one could register a function for a new set of types as follows:

import lab as B
from plum import Signature, Union
import scipy.sparse as sp

SparseArray = Union(
    sp.bsr_matrix,
    # ...
    alias="SparseArray",
)
_SparseArraySign = Signature(SparseArray)

def sparse_transpose(a):
    return a.T

B.T.register(_SparseArraySign, sparse_transpose)

B.T(sp.csr_array([0, 0]))  # !

This no longer works. As a first step, I've updated the Union to the new subscript syntax:

SparseArray = Union[sp.bsr_matrix]

but calling B.T still fails:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In [6], line 1
----> 1 B.T(sp.csr_array([0, 0]))

File ~/miniconda3/lib/python3.9/site-packages/plum/function.py:342, in Function.__call__(self, *args, **kw_args)
    338 def __call__(self, *args, **kw_args):
    339     # Before attempting to use the cache, resolve any unresolved registrations. Use
    340     # an `if`-statement to speed up the common case.
    341     if self._pending:
--> 342         self._resolve_pending_registrations()
    344     # Attempt to use the cache based on the types of the arguments.
    345     types = tuple(map(type, args))

File ~/miniconda3/lib/python3.9/site-packages/plum/function.py:224, in Function._resolve_pending_registrations(self)
    220     signature = extract_signature(f, precedence=precedence)
    221 else:
    222     # Ensure that the implementation is `f`, but make a copy before
    223     # mutating.
--> 224     signature = signature.__copy__()
    225     signature.implementation = f
    227 # Ensure that the implementation has the right name, because this name
    228 # will show up in the docstring.

AttributeError: 'function' object has no attribute '__copy__'

Is there a fix for this, and could the documentation be updated to show how to register functions like this?

This came up in GeometricKernels.

wesselb commented 1 year ago

Hey @InfProbSciX! I’m very sorry this is broken now. Just sending a quick message here to let you know that I’ve seen this and will reply very soon with how this now works.

wesselb commented 1 year ago

Hey @InfProbSciX. I believe that your snippet should work with minimal changes:

from typing import Union

import lab as B
import scipy.sparse as sp
from plum import Signature

SparseArray = Union[
    sp.bsr_matrix,
    int,  # Dummy for other types
]
_SparseArraySign = Signature(SparseArray)

def sparse_transpose(a):
    return a.T

# The implementation now comes first, followed by the signature.
B.T.register(sparse_transpose, _SparseArraySign)

B.T(sp.bsr_array([0, 0]))  # OK
B.T(sp.csr_array([0, 0]))  # error

Is this what you're after?
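For anyone else hitting the same AttributeError: judging from the traceback above, the failure seems to come from the argument order of `register`, since the function ended up in the slot where a signature was expected and a plain function has no `__copy__`. Below is a minimal sketch of that mechanism. `ToySignature` and `ToyFunction` are hypothetical stand-ins written only for illustration; they are not plum's classes and only mimic the copy-then-attach step visible in the traceback.

```python
class ToySignature:
    """Hypothetical stand-in for a dispatch signature (not plum's class)."""

    def __init__(self, *types):
        self.types = types
        self.implementation = None

    def __copy__(self):
        return ToySignature(*self.types)


class ToyFunction:
    """Hypothetical registry that expects the implementation first."""

    def __init__(self):
        self._pending = []
        self._resolved = []

    def register(self, f, signature=None):
        # Registration is lazy, so nothing is validated at this point.
        self._pending.append((f, signature))

    def resolve(self):
        pending, self._pending = self._pending, []
        for f, signature in pending:
            # Copy before mutating, mirroring the step shown in the traceback.
            signature = signature.__copy__()
            signature.implementation = f
            self._resolved.append(signature)


def sparse_transpose(a):
    return a.T


# Old argument order: the function lands in the `signature` slot.
swapped = ToyFunction()
swapped.register(ToySignature(int), sparse_transpose)
try:
    swapped.resolve()
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

# New argument order: implementation first, signature second.
fixed = ToyFunction()
fixed.register(sparse_transpose, ToySignature(int))
fixed.resolve()
```

Because registration is deferred in this sketch, the swapped call only fails once the pending registrations are resolved, which would be consistent with the error surfacing at the first call to `B.T` rather than at `register`.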

Note that union aliasing is no longer built in, because Plum now runs entirely on standard type hints. The same behaviour can still be achieved, as described in the Plum documentation. It works as follows:

>>> from plum import activate_union_aliases, set_union_alias
>>> activate_union_aliases()
>>> set_union_alias(SparseArray, "SparseArray")
>>> SparseArray
typing.Union[SparseArray]

This feature must be activated explicitly because it patches typing.Union.__repr__ and might therefore cause unintended behaviour.
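As background for the note on aliases, here is a small standard-library sketch of how plain `typing.Union` objects behave. A union exposes its members but carries no name of its own, and a union with a single member collapses to that member, which is presumably why the snippet above adds a dummy `int`. The variable names are illustrative only.

```python
from typing import Union, get_args

# A one-member union is not a union at all: typing returns the member itself.
Single = Union[int]

# With two or more members, the result is a real union with inspectable members.
Pair = Union[int, str]
members = get_args(Pair)

# The tuple of members can be used directly for instance checks.
accepts_int = isinstance(3, members)
accepts_float = isinstance(3.0, members)
```

Since there is nowhere on a standard union to store a display name, attaching one requires an opt-in patch like the one described above.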