Open — stoprightthere opened this issue 6 months ago
Hey @stoprightthere! Excellent suggestion.
LAB uses a single `dispatch = Dispatcher()` throughout the codebase. We could add a keyword `hint`:

```python
dispatch = Dispatcher(hint="Did you import the right backend?")
```

This hint would then be appended to every `NotFoundLookupError`:
```
In [1]: import lab as B
In [2]: import torch
In [3]: B.randn(torch.float32)
<traceback>
NotFoundLookupError: `randn(torch.float32)` could not be resolved.
Closest candidates are the following:
randn(dtype: typing.Union[type, numpy.dtype], *shape: typing.Union[int, lab.shape.Dimension, numpy.int8, numpy.int16, numpy.int32, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64])
precedence=1
<function randn at 0x7f94f06313a0> @ ~/Dropbox/Projects/PyLib/LAB/lab/shape.py:50
randn(*shape: typing.Union[int, lab.shape.Dimension, numpy.int8, numpy.int16, numpy.int32, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64])
<function randn at 0x7f94f0641310> @ ~/Dropbox/Projects/PyLib/LAB/lab/random.py:161
randn(ref: typing.Union[int, lab.shape.Dimension, numpy.int8, numpy.int16, numpy.int32, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, bool, numpy.bool_, float, numpy.float16, numpy.float32,
numpy.float64, numpy.float128, complex, numpy.complex64, numpy.complex128, numpy.complex256, numpy.ndarray, plum.type.ModuleType[autograd.tracer.Box], plum.type.ModuleType[tensorflow.Tensor],
plum.type.ModuleType[tensorflow.Variable], plum.type.ModuleType[tensorflow.IndexedSlices], plum.type.ModuleType[keras.KerasTensor], plum.type.ModuleType[jaxlib.xla_extension.ArrayImpl],
plum.type.ModuleType[jax.core.Tracer], torch.Tensor])
<function randn at 0x7f94f0641550> @ ~/Dropbox/Projects/PyLib/LAB/lab/random.py:171
Did you import the right backend?
```
How does that look?
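For concreteness, the behaviour could already be prototyped with a thin wrapper around any dispatched function. This is only a rough sketch: `add_hint` is a made-up name, and for simplicity it chains a plain `LookupError` instead of reconstructing plum's own exception type.

```python
import functools

from plum import NotFoundLookupError


def add_hint(f, hint):
    """Append `hint` to the error raised when `f` cannot be resolved."""

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except NotFoundLookupError as e:
            # Simplification: chain a plain LookupError carrying the original
            # message plus the hint, rather than rebuilding plum's exception.
            raise LookupError(f"{e}\n{hint}") from e

    return wrapper


# Hypothetical usage:
# randn = add_hint(B.randn, "Did you import the right backend?")
```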
The suggested methods, by the way, would become more readable with `plum.activate_union_aliases()`:
```
In [4]: from plum import activate_union_aliases
In [5]: activate_union_aliases()
In [6]: B.randn(torch.float32)
<traceback>
NotFoundLookupError: `randn(torch.float32)` could not be resolved.
Closest candidates are the following:
randn(dtype: typing.Union[B.NPDType], *shape: typing.Union[B.Int])
precedence=1
<function randn at 0x7f94f06313a0> @ ~/Dropbox/Projects/PyLib/LAB/lab/shape.py:50
randn(*shape: typing.Union[B.Int])
<function randn at 0x7f94f0641310> @ ~/Dropbox/Projects/PyLib/LAB/lab/random.py:161
randn(ref: typing.Union[B.Number, B.JAXRandomState, B.AGNumeric, B.TFNumeric, torch.Tensor])
<function randn at 0x7f94f0641550> @ ~/Dropbox/Projects/PyLib/LAB/lab/random.py:171
Did you import the right backend?
```
Unfortunately, this is an experimental feature which I wouldn't activate by default.
Hi,

We've been having some issues at geometric_kernels with people forgetting to import the modules that define backend-specific functions, e.g. `import geometric_kernels.torch`. Currently the error messages in that case are a bit cryptic and leave people confused.

I think it would be useful to have a "catch-all" dispatcher that reminds the user to import the backend-specific code if they try to run a dispatched function with numeric inputs. It should still raise `NotFoundLookupError` if the inputs are completely off.
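Roughly, the idea would be something like the sketch below. It is only a minimal illustration, not our actual code: `expand_basis` is a made-up function name, and it assumes LAB's `B.Numeric` union type as the annotation for the fallback.

```python
import lab as B
from plum import Dispatcher

dispatch = Dispatcher()


# Generic fallback: matches any numeric input for which no more specific,
# backend-specific method has been registered, and raises a reminder.
# Once the backend module (e.g. geometric_kernels.torch) is imported and
# registers a torch.Tensor method, that method is more specific than
# B.Numeric and wins. Non-numeric inputs don't match this method at all,
# so they still fail with the usual NotFoundLookupError.
@dispatch
def expand_basis(x: B.Numeric):
    raise LookupError(
        "No backend-specific method of `expand_basis` was found. "
        "Did you forget to import the backend module, "
        "e.g. `import geometric_kernels.torch`?"
    )
```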
I came up with my own half-baked attempt at this:
Do you think it'd be useful to incorporate something like this into lab? I think it would definitely reduce the amount of confusion.