numba / numba-scipy

numba_scipy extends Numba to make it aware of SciPy
https://numba-scipy.readthedocs.io/en/latest/
BSD 2-Clause "Simplified" License

Implement array-valued signatures #56

Open adeak opened 3 years ago

adeak commented 3 years ago

As of https://github.com/numba/numba-scipy/pull/54, the simplest scalar calls to jitted special functions should work.

However, there's no support yet for array-valued inputs:

import numpy as np
from numba import njit
from scipy import special

x = np.linspace(-10, 10, 1000)

@njit
def jitted_j0(x):
    res = special.j0(x[0])  # works after PR #54
    # res = special.j0(x)  # breaks
    return res

print(jitted_j0(x))

This is not obviously a shortcoming, since looping in jitted functions should be fine. So this is just a mild suggestion to consider adding support for array-valued signatures. (This should probably be preceded by some benchmarks to see whether it would help performance at all.)

brandonwillard commented 3 years ago

> This is not obviously a shortcoming, since looping in jitted functions should be alright.

It's definitely a shortcoming, because the corresponding scipy.special functions that are being overloaded are ufuncs and do not have this limitation.

I would say that it doesn't render the library useless, though.

Anyway, I took a quick shot at using numba.vectorize on the functions produced by choose_kernel, but numba.extending.overload does not like the returned type of a numba.vectorize-wrapped function. Is that expected?

esc commented 3 years ago

> Anyway, I took a quick shot at using numba.vectorize on the functions produced by choose_kernel, but numba.extending.overload does not like the returned type of a numba.vectorize-wrapped function. Is that expected?

Do you have an example, perchance?

brandonwillard commented 3 years ago
modified   numba_scipy/special/overloads.py
@@ -10,7 +10,12 @@ def choose_kernel(name, all_signatures):
         for signature in all_signatures:
             if args == signature:
                 f = signatures.name_and_types_to_pointer[(name, *signature)]
-                return lambda *args: f(*args)
+
+                @numba.vectorize
+                def _f(*args):
+                    return f(*args)
+
+                return _f

     return choice_function

results in the following error:

E   numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
E   No implementation of function Function(<ufunc 'agm'>) found for signature:
E    
E    >>> agm(float64, float64)
E    
E   There are 2 candidate implementations:
E     - Of which 2 did not match due to:
E     Overload in function 'choose_kernel.<locals>.choice_function': File: ../code/python/numba-scipy/numba_scipy/special/overloads.py: Line 9.
E       With argument(s): '(float64, float64)':
E      Rejected as the implementation raised a specific error:
E        AssertionError: Implementator function returned by `@overload` has an unexpected type.  Got <numba._DUFunc '_f'>
E     raised from ~/envs/numba-scipy-env/lib/python3.7/site-packages/numba/core/typing/templates.py:742
E   
E   During: resolving callee type: Function(<ufunc 'agm'>)
E   During: typing of call at ~/code/python/numba-scipy/numba_scipy/tests/test_special.py (76)
E   
E   
E   File "numba_scipy/tests/test_special.py", line 76:
E       def numba_func(*args):
E           return scipy_func(*args)
E           ^

Is numba.extending.overload attempting to numba.jit the function returned by choose_kernel? The error looks similar to the one produced when numba.njit-ing a function wrapped with numba.vectorize.

brandonwillard commented 3 years ago

The varargs could also be a problem.

brandonwillard commented 3 years ago

I have a hack to get this working in my vectorized-overloads branch. It creates a fixed-argument function on the fly to get past some apparent varargs issues with numba.vectorize.

If anyone knows how to get past this varargs issue without creating functions in this fashion (or via any other fundamentally AST-based approach), please tell me; it would really help with the work we're doing in Aesara as well.

stuartarchibald commented 3 years ago

There's no public extension API in Numba for declaring this in a simple manner, but this sort of thing could be a workaround:

from numba import njit, vectorize, types
from numba.extending import overload
import numpy as np
from scipy import special

x = np.linspace(-10, 10, 1000)

# This is just a dummy scalar function standing in for those in numba-scipy's
# wrapper for scipy.special.*; now that #54 is in, the standard overload for
# scalar j0 should just work.
@njit
def pretend_j0_from_cython(x):
    return x + 12.34

@vectorize
def vectorize_j0(x):
    return pretend_j0_from_cython(x)

# This gets the vectorization mechanics but will end up "hiding" the NumPy ufunc
@overload(special.j0)
def ol_j0(x):
    if isinstance(x, (types.Array, types.Number)):
        def impl(x):
            return vectorize_j0(x)
        return impl

@njit
def jitted_j0(x):
    res1 = special.j0(x[0])  # scalar input
    res2 = special.j0(x)     # array input, handled by the vectorized function
    return res1, res2

print(jitted_j0(x))

brandonwillard commented 3 years ago

The issue I ran into above is the signature for the @vectorized function: varargs wouldn't work, so I had to construct the function via compile/AST.

stuartarchibald commented 3 years ago

> The issue I ran into above is the signature for the @vectorized function: varargs wouldn't work, so I had to construct the function via compile/AST.

Ah, I see, I misinterpreted this as not being able to register an overload with vectorize, and whilst that's a problem, I can see why *args failing is also a problem if you want to do that automatic generation!

Opened https://github.com/numba/numba/issues/6954 to track.

brandonwillard commented 3 years ago

> Opened numba/numba#6954 to track.

Thanks for that; it's a problem that shows up in at least a couple places where we're trying to use Numba as a backend (e.g. here).

PabloRdrRbl commented 3 years ago

Hello, I have been able to use the workaround by @stuartarchibald. Is there any plan to add this, so that there is no need to write a vectorized version of every function?

esc commented 3 years ago

@PabloRdrRbl I think a PR has already been opened: https://github.com/numba/numba-scipy/pull/58

PabloRdrRbl commented 2 years ago


Is it possible to extend @stuartarchibald's workaround to a function like jv, which takes two arguments?