adeak opened this issue 3 years ago.
This is not obviously a shortcoming, since looping in jitted functions should be alright.
It's definitely a shortcoming, because the corresponding scipy.special functions that are being overloaded are ufuncs and do not have this limitation.
I would say that it doesn't render the library useless, though.
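For reference, here is a quick check (assuming only NumPy and SciPy are installed) that the scipy.special functions in question really are NumPy ufuncs. That is why they broadcast over array arguments without any explicit loop:

```python
import numpy as np
from scipy import special

# scipy.special functions such as j0 and jv are NumPy ufuncs, so they
# broadcast over array arguments without an explicit loop.
print(isinstance(special.j0, np.ufunc), isinstance(special.jv, np.ufunc))  # → True True

# nin reports the number of inputs: j0 takes one argument, jv takes two.
print(special.j0.nin, special.jv.nin)  # → 1 2

xs = np.linspace(0.0, 10.0, 5)
print(special.j0(xs))       # elementwise over the array
print(special.jv(1.0, xs))  # scalar order broadcast against an array argument
```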
Anyway, I took a quick shot at using numba.vectorize on the functions produced by choose_kernel, but numba.extending.overload does not like the returned type of a numba.vectorize-wrapped function. Is that expected?
Do you have an example, perchance?
```diff
modified   numba_scipy/special/overloads.py
@@ -10,7 +10,12 @@ def choose_kernel(name, all_signatures):
         for signature in all_signatures:
             if args == signature:
                 f = signatures.name_and_types_to_pointer[(name, *signature)]
-                return lambda *args: f(*args)
+
+                @numba.vectorize
+                def _f(*args):
+                    return f(*args)
+
+                return _f
     return choice_function
```
This results in the following error:
```
E   numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
E   No implementation of function Function(<ufunc 'agm'>) found for signature:
E
E    >>> agm(float64, float64)
E
E   There are 2 candidate implementations:
E     - Of which 2 did not match due to:
E     Overload in function 'choose_kernel.<locals>.choice_function': File: ../code/python/numba-scipy/numba_scipy/special/overloads.py: Line 9.
E       With argument(s): '(float64, float64)':
E      Rejected as the implementation raised a specific error:
E        AssertionError: Implementator function returned by `@overload` has an unexpected type. Got <numba._DUFunc '_f'>
E     raised from ~/envs/numba-scipy-env/lib/python3.7/site-packages/numba/core/typing/templates.py:742
E
E   During: resolving callee type: Function(<ufunc 'agm'>)
E   During: typing of call at ~/code/python/numba-scipy/numba_scipy/tests/test_special.py (76)
E
E
E   File "numba_scipy/tests/test_special.py", line 76:
E   def numba_func(*args):
E       return scipy_func(*args)
E       ^
```
Is numba.extending.overload attempting to numba.jit the function returned by choose_kernel? The error looks similar to the one produced when numba.njit-ing a function wrapped with numba.vectorize.
The varargs could also be a problem.
I have a hack to get this working in my vectorized-overloads branch. It creates a fixed-arguments function on the fly to get past some apparent varargs issues with numba.vectorize.
If anyone knows how to get past this varargs issue without creating functions in this fashion (or via any other fundamentally AST-based approach), please tell me; it would really help with the work we're doing in Aesara as well.
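The "fixed-arguments function on the fly" idea can be sketched with the standard library alone. The helper below, `make_fixed_arity` (a hypothetical name, not the code from the vectorized-overloads branch), generates source text for a function with explicit positional parameters that forwards to a given callable, then compiles it with `exec`. The result has a concrete signature instead of `*args`, which is what numba.vectorize needs:

```python
import inspect


def make_fixed_arity(func, nargs, name="_fixed"):
    """Return a new function taking exactly `nargs` positional arguments
    and forwarding them to `func` (hypothetical helper for illustration)."""
    params = ", ".join(f"a{i}" for i in range(nargs))
    source = f"def {name}({params}):\n    return _target({params})\n"
    # Run the generated source in a fresh namespace where `_target` is bound
    # to `func`; the new function looks it up there as a global.
    namespace = {"_target": func}
    exec(compile(source, f"<generated {name}>", "exec"), namespace)
    return namespace[name]


def varargs_kernel(*args):
    # Stand-in for a kernel that only exposes a varargs signature.
    return sum(args)


fixed = make_fixed_arity(varargs_kernel, 2)
print(inspect.signature(fixed))  # → (a0, a1)
print(fixed(1.5, 2.5))           # → 4.0
```

In the numba-scipy setting, `func` would have to be something Numba can already compile in nopython mode (for example an `@njit` function or the ctypes function pointer used in choose_kernel), and the generated function could then be passed to `numba.vectorize`. That last step is an assumption and is not exercised above.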
There's no public extension API in Numba for declaring this in a simple manner, but something like this could serve as a workaround:
```python
from numba import njit, vectorize, types
from numba.extending import overload
import numpy as np
from scipy import special

x = np.linspace(-10, 10, 1000)

# This is just a dummy scalar function, cf. those in numba-scipy's wrapper for
# scipy.special.*; now that #54 is in, the standard overload for scalar j0
# should just work.
@njit
def pretend_j0_from_cython(x):
    return x + 12.34

@vectorize
def vectorize_j0(x):
    return pretend_j0_from_cython(x)

# This gets the vectorization mechanics but will end up "hiding" the NumPy ufunc
@overload(special.j0)
def ol_j0(x):
    if isinstance(x, (types.Array, types.Number)):
        def impl(x):
            return vectorize_j0(x)
        return impl

@njit
def jitted_j0(x):
    res1 = special.j0(x[0])  # scalar input
    res2 = special.j0(x)     # array input
    return res1, res2

print(jitted_j0(x))
```
The issue I ran into above is the signature for the @vectorize-decorated function: varargs wouldn't work, so I had to construct the function via compile/AST.
Ah, I see, I misinterpreted this as not being able to register an overload with vectorize, and whilst that's a problem, I can see why *args failing is also a problem if you want to do that automatic generation!
Opened https://github.com/numba/numba/issues/6954 to track.
Thanks for that; it's a problem that shows up in at least a couple places where we're trying to use Numba as a backend (e.g. here).
Hello, I have been able to use the workaround by @stuartarchibald. Is there any plan to add this, so that there is no need to write the vectorized version of every function?
@PabloRdrRbl I think a PR has already been opened: https://github.com/numba/numba-scipy/pull/58
Is it possible to extend this workaround to a function like jv, which takes two arguments?
As of https://github.com/numba/numba-scipy/pull/54, the simplest scalar calls to jitted special functions should work.
However, there's no support yet for array-valued inputs:
This is not obviously a shortcoming, since looping in jitted functions should be alright. So this is just a mild suggestion to consider adding support for array-valued signatures. (This should probably be preceded by some benchmarks to see whether this would help anything performance-wise.)