wesselb / lab

A generic interface for linear algebra backends
MIT License

Catching forgotten backend-specific imports #19

Open stoprightthere opened 6 months ago

stoprightthere commented 6 months ago

Hi,

We've been having some issues (at geometric_kernels) with people forgetting to import the modules that define backend-specific functions, as in `import geometric_kernels.torch`. Currently the error messages are a bit cryptic in that case and leave people confused.

I think it would be useful to have a "catch-all" dispatcher that reminds the user to import the backend-specific code if they call a dispatched function with numeric inputs that no loaded implementation handles. It should still raise NotFoundLookupError if the inputs are completely off (e.g. not numeric at all).

I came up with my own half-baked attempt at this:

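import lab as B
import numpy as np
import plum
import torch
from plum import Dispatcher

dispatch = Dispatcher()


def promised_dispatch(error_msg, precedence=0):
    """
    Decorator for a "promised" function. The implementation is not given yet, but it will be (e.g. in a separate, backend-specific module).
    """
    def wrap(f):
        _f = dispatch.abstract(f)
        # Fallback signature built from the annotations of `f`, e.g. `(x: B.Numeric)`.
        signature = plum.Signature.from_callable(f, precedence=precedence)

        def _fallback_f(*args, **kwargs):
            # Plum only selects this method when `args` match `signature` and no
            # more specific method exists, so it can raise unconditionally.
            raise RuntimeError(error_msg)

        _f.register(_fallback_f, signature, precedence)
        return _f
    return wrap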

@promised_dispatch("Did you forget to do `import geometric_kernels.<backend>`?")
def ff(x: B.Numeric):
    pass

@dispatch
def ff(x: B.NPNumeric):
    return x + 1

ff(np.r_[3])    # returns array([4])

ff(torch.tensor([3]))  # raises RuntimeError: "Did you forget to do ...?"

ff('3')  # raises NotFoundLookupError: `ff('3')` could not be resolved.
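For readers less familiar with plum, the same resolution logic can be sketched with the standard library's `functools.singledispatch`. This is an analogy under stated assumptions, not plum or LAB code: `numbers.Number` plays the role of `B.Numeric`, `int` stands in for a backend type whose module has been imported, and `float` for one whose module has not. The default implementation mimics `NotFoundLookupError` with a plain `LookupError`.

```python
import numbers
from functools import singledispatch


@singledispatch
def ff(x):
    # Default for every unregistered type: plays the role of NotFoundLookupError.
    raise LookupError(f"`ff({x!r})` could not be resolved.")


@ff.register
def _(x: numbers.Number):
    # "Promised" fallback: the input is numeric, but no implementation has been
    # registered for its concrete type (i.e. the backend module was not imported).
    raise RuntimeError("Did you forget to import the backend-specific module?")


# What a backend-specific module would register once imported.
@ff.register
def _(x: int):
    return x + 1


print(ff(3))  # 4
```

Here `ff(3.0)` hits the promised fallback and raises the reminder, while `ff('3')` falls through to the default. As in the plum version, the fallback only fires for inputs that are numeric but lack a more specific method.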

Do you think it'd be useful to incorporate something like this into lab? I think this will definitely reduce the amount of confusion.

wesselb commented 6 months ago

Hey @stoprightthere! Excellent suggestion.

LAB uses a single `dispatch = Dispatcher()` throughout the codebase. We could give the dispatcher a keyword argument `hint`:

dispatch = Dispatcher(hint="Did you import the right backend?")

This hint would then be appended to every NotFoundLookupError:

In [1]: import lab as B

In [2]: import torch

In [3]: B.randn(torch.float32)
<traceback>
NotFoundLookupError: `randn(torch.float32)` could not be resolved.

Closest candidates are the following:
    randn(dtype: typing.Union[type, numpy.dtype], *shape: typing.Union[int, lab.shape.Dimension, numpy.int8, numpy.int16, numpy.int32, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64])
        precedence=1
        <function randn at 0x7f94f06313a0> @ ~/Dropbox/Projects/PyLib/LAB/lab/shape.py:50
    randn(*shape: typing.Union[int, lab.shape.Dimension, numpy.int8, numpy.int16, numpy.int32, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64])
        <function randn at 0x7f94f0641310> @ ~/Dropbox/Projects/PyLib/LAB/lab/random.py:161
    randn(ref: typing.Union[int, lab.shape.Dimension, numpy.int8, numpy.int16, numpy.int32, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, bool, numpy.bool_, float, numpy.float16, numpy.float32,
    numpy.float64, numpy.float128, complex, numpy.complex64, numpy.complex128, numpy.complex256, numpy.ndarray, plum.type.ModuleType[autograd.tracer.Box], plum.type.ModuleType[tensorflow.Tensor],
    plum.type.ModuleType[tensorflow.Variable], plum.type.ModuleType[tensorflow.IndexedSlices], plum.type.ModuleType[keras.KerasTensor], plum.type.ModuleType[jaxlib.xla_extension.ArrayImpl],
    plum.type.ModuleType[jax.core.Tracer], torch.Tensor])
        <function randn at 0x7f94f0641550> @ ~/Dropbox/Projects/PyLib/LAB/lab/random.py:171

Did you import the right backend?
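Without changing `Dispatcher` itself, a similar effect could be approximated from the outside by catching the lookup error and re-raising it with the hint appended. Below is a minimal sketch: `with_hint` and `randn_stub` are hypothetical helpers, and the local `NotFoundLookupError` class only stands in for plum's so that the block runs on its own.

```python
import functools


class NotFoundLookupError(LookupError):
    """Stand-in for plum's NotFoundLookupError so the sketch is self-contained."""


def with_hint(exc_type, hint):
    """Hypothetical decorator: re-raise `exc_type` errors with `hint` appended."""

    def decorator(f):
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except exc_type as e:
                # Keep the original message, add the hint, and chain the cause.
                raise type(e)(f"{e}\n\n{hint}") from e

        return wrapped

    return decorator


@with_hint(NotFoundLookupError, "Did you import the right backend?")
def randn_stub(ref):
    # Hypothetical stand-in for a dispatched function that fails to resolve.
    raise NotFoundLookupError(f"`randn({ref!r})` could not be resolved.")
```

One caveat compared with a built-in `hint`: the wrapper also catches lookup errors raised by nested dispatched calls, so stacking wrapped functions could append the hint more than once.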

How does this look to you?


By the way, the suggested candidate methods would become more readable with `plum.activate_union_aliases()`:

In [4]: from plum import activate_union_aliases

In [5]: activate_union_aliases()

In [6]: B.randn(torch.float32)
<traceback>
NotFoundLookupError: `randn(torch.float32)` could not be resolved.

Closest candidates are the following:
    randn(dtype: typing.Union[B.NPDType], *shape: typing.Union[B.Int])
        precedence=1
        <function randn at 0x7f94f06313a0> @ ~/Dropbox/Projects/PyLib/LAB/lab/shape.py:50
    randn(*shape: typing.Union[B.Int])
        <function randn at 0x7f94f0641310> @ ~/Dropbox/Projects/PyLib/LAB/lab/random.py:161
    randn(ref: typing.Union[B.Number, B.JAXRandomState, B.AGNumeric, B.TFNumeric, torch.Tensor])
        <function randn at 0x7f94f0641550> @ ~/Dropbox/Projects/PyLib/LAB/lab/random.py:171

Did you import the right backend?

Unfortunately, this is an experimental feature which I wouldn't activate by default.
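To make the collapsing in the second traceback concrete, here is a toy sketch of how a union's members could be rewritten in terms of registered aliases, trying larger aliases first. It works on type names as strings, `render_union` is a hypothetical helper, and this is an assumption about the general idea, not plum's actual implementation.

```python
def render_union(members, aliases):
    """Render a union, collapsing any group of members that forms a known alias.

    `members` is a list of type names; `aliases` maps an alias name to the set
    of type names it stands for. Both are hypothetical, string-based inputs.
    """
    remaining = list(members)
    parts = []
    # Try larger aliases first so that, e.g., `Number` wins over `Int`.
    for name, group in sorted(aliases.items(), key=lambda kv: -len(kv[1])):
        if group <= set(remaining):
            parts.append(name)
            remaining = [m for m in remaining if m not in group]
    return "Union[" + ", ".join(parts + remaining) + "]"


print(render_union(["int", "float", "torch.Tensor"], {"Number": {"int", "float"}}))
# Union[Number, torch.Tensor]
```

Members that are not covered by any alias, like `torch.Tensor` here, are printed as-is, which matches the shape of the aliased output above.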