numba / numba-scipy

numba_scipy extends Numba to make it aware of SciPy
https://numba-scipy.readthedocs.io/en/latest/
BSD 2-Clause "Simplified" License

Casting for special functions #14

Open person142 opened 5 years ago

person142 commented 5 years ago

Currently numba_scipy.special does not cast special function arguments the way the special ufuncs do. A basic example is:

>>> import numba
>>> import numpy as np
>>> import scipy.special as sc
>>> import numba_scipy.special
>>> @numba.njit
... def gammaln(x):
...     return sc.gammaln(x)
...
>>> sc.gammaln(1.0)
0.0
>>> sc.gammaln(np.float32(1.0))
0.0
>>> gammaln(1.0)
0.0
>>> gammaln(np.float32(1.0))
Traceback (most recent call last):
...

(Note that most functions in special don't have specific float32 signatures; they just cast to float64 and then cast back.)

This issue is to discuss: how should we handle the casting?

For numba_special I had a branch that handled this in the following way:

It works but is maybe a little messy. Is there a better way to handle this?

person142 commented 5 years ago

After thinking about this for a while, I think it’s probably better to just explicitly generate the casted signatures at codegen time and add them to signatures.py. Keeps the runtime overloading logic simple.
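A rough sketch of what generating the casted signatures ahead of time could look like. This is not the numba_scipy generator: the type codes follow NumPy's ufunc convention, while `SAFE_CASTS`, `expand_signatures`, and the example base signatures are hypothetical names and data chosen only for illustration.

```python
from itertools import product

# Hypothetical table: for each base type code, the input codes that can be
# safely cast to it. Codes follow NumPy's ufunc convention
# ('f' = float32, 'd' = float64, 'i' = int32, 'l' = long).
SAFE_CASTS = {
    "d": ("d", "f"),  # float64 accepts float64 and float32
    "l": ("l", "i"),  # long accepts long and int32
}


def expand_signatures(base_signatures):
    """Map every accepted input-type tuple to the base signature it casts to.

    Exact signatures are registered first, so a casted variant can never
    overwrite an exact match.
    """
    table = {}
    # Register the exact signatures first.
    for base in base_signatures:
        table.setdefault(base, base)
    # Then add casted variants that do not collide with an earlier entry.
    for base in base_signatures:
        for variant in product(*(SAFE_CASTS[code] for code in base)):
            table.setdefault(variant, base)
    return table


# Illustrative base signatures for a two-argument function (assumed data).
base = [("l", "d"), ("d", "d")]
for inputs, target in expand_signatures(base).items():
    print("".join(inputs), "->", "".join(target))
```

The resulting table could then be written out as plain data, so the runtime lookup stays a simple dictionary access.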

I’ll submit a PR soon(ish).

stuartarchibald commented 5 years ago

Would it work to just add the cast at the function call site? I suppose the challenge is going to be when there's multiple matching implementations that might work.

stuartarchibald commented 5 years ago

Also, looking forward to the PR, thanks! I'm hoping to ship a 0.2.0 with Numba 0.46 shortly.

person142 commented 5 years ago

> I suppose the challenge is going to be when there's multiple matching implementations that might work.

Yeah you’d have to make sure the signatures are sorted by specificity and check for exact matches first.
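A minimal sketch of that lookup order, written as plain Python rather than Numba typing code. The `resolve` helper, the `CAN_CAST` table, and the example signatures are all hypothetical, and the type codes again follow NumPy's ufunc convention.

```python
# Hypothetical safe-cast relation between NumPy-style type codes:
# CAN_CAST[src] lists the codes that src may be converted to.
CAN_CAST = {
    "i": ("i", "l", "d"),  # int32 -> int32, long, float64
    "l": ("l", "d"),       # long -> long, float64
    "f": ("f", "d"),       # float32 -> float32, float64
    "d": ("d",),           # float64 -> float64 only
}


def resolve(arg_types, signatures):
    """Return the signature to use for arg_types, or None if nothing fits.

    `signatures` must already be sorted from most to least specific.
    """
    arg_types = tuple(arg_types)
    # 1. An exact match always wins.
    if arg_types in signatures:
        return arg_types
    # 2. Otherwise take the first signature every argument can be cast to.
    for sig in signatures:
        if len(sig) == len(arg_types) and all(
            dst in CAN_CAST[src] for src, dst in zip(arg_types, sig)
        ):
            return sig
    return None


signatures = [("l", "d"), ("d", "d")]  # integer binding listed first
print(resolve(("d", "d"), signatures))  # exact match
print(resolve(("i", "f"), signatures))  # casts to the integer binding
print(resolve(("f", "f"), signatures))  # falls back to the float64 binding
```

Because the list order decides between several castable candidates, the sort by specificity has to happen before any lookup.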

In that scenario I’m also not sure how you can generate the casting function on the fly from the types in a way that compiles to efficient machine code. Quite possibly a lack of knowledge on my part though.

If you can’t generate them on the fly, then you have to codegen them, and at that point you might as well just generate the full signature. (That’s been my train of thought at least.)

stuartarchibald commented 5 years ago

IIRC in the Numba code base, the ufuncs just have large, spelled-out maps of what's accepted; there may be no better way. I suppose, however, that in this case the Numba type system can be leaned on a bit to do something like:

Suppose there's a special function foo with two bindings:

Something along the lines of:

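from numba import types
from numba.extending import overload

# `foo` is the hypothetical special function being overloaded.
@overload(foo)
def ol_foo(x, y):
    # At typing time, x and y are Numba types, not values.
    if not isinstance(y, types.Float):  # y must be a float type
        return None

    if isinstance(x, types.Float):
        # Float binding: cast both arguments to float64.
        def impl(x, y):
            return foo(types.float64(x), types.float64(y))
        return impl
    elif isinstance(x, types.Integer):
        # Integer binding: cast x to long and y to float64.
        def impl(x, y):
            return foo(types.long_(x), types.float64(y))
        return impl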

might work and could be ok to generate? This may take some iterating to get right but it'd be useful to establish a working pattern.