numba / numba

NumPy aware dynamic Python compiler using LLVM
https://numba.pydata.org/
BSD 2-Clause "Simplified" License

Speedups for jitted reductions #2176

Open ml31415 opened 8 years ago

ml31415 commented 8 years ago

I was playing around with code like this:

import numpy as np
import numba as nb
import bottleneck as bn
import numbagg as nbg

@nb.njit
def nanmin_demo(x):
    if x.size == 0:
        raise ValueError("nanmin(): empty array")
    ret = np.nan
    for i, x_ in enumerate(x.flat):
        if not np.isnan(x_):
            ret = x_
            break
    if np.isnan(ret):
        return ret

    # This should be x.flat[i:], so that the array is
    # not iterated again unnecessarily. Better ideas?
    for x_ in x.flat:
        if x_ < ret:
            ret = x_
    return ret

@nb.njit
def nanmin_numba(a):
    if a.size == 0:
        raise ValueError("nanmin(): empty array")
    for view in np.nditer(a):
        minval = view.item()
        break
    for view in np.nditer(a):
        v = view.item()
        if not minval < v and not np.isnan(v):
            minval = v
    return minval

@nb.njit
def nanmin_numbagg_1dim(a):
    amin = np.inf  # np.infty was removed in NumPy 2.0
    all_missing = 1
    for ai in a.flat:
        if ai <= amin:
            amin = ai
            all_missing = 0
    if all_missing:
        amin = np.nan
    return amin

impls = [nanmin_demo, nanmin_numba, nanmin_numbagg_1dim, nbg.nanmin, bn.nanmin]

for i in range(4):
    x = np.random.random(100000)
    x[x>0.3334*i] = np.nan
    res = np.array([impl(x) for impl in impls])
    assert np.all(res[0] == res) or np.all(np.isnan(res))
    x = x.reshape((-1, 100))  # reshape returns a new view; the result must be kept
    res = np.array([impl(x) for impl in impls])
    assert np.all(res[0] == res) or np.all(np.isnan(res))
    for impl in impls:
        %timeit impl(x)
    print('--------')
print(nb.__version__)

nanmin_impl is the currently implemented overload. It is compared with the built-in NumPy version, bottleneck, and another experimental implementation. The timings look like this for me:

10000 loops, best of 3: 63.5 µs per loop
10000 loops, best of 3: 154 µs per loop
10000 loops, best of 3: 62.8 µs per loop
10000 loops, best of 3: 151 µs per loop
10000 loops, best of 3: 68.4 µs per loop
--------
10000 loops, best of 3: 50.9 µs per loop
1000 loops, best of 3: 405 µs per loop
10000 loops, best of 3: 57.3 µs per loop
10000 loops, best of 3: 151 µs per loop
10000 loops, best of 3: 68.4 µs per loop
--------
10000 loops, best of 3: 51.1 µs per loop
1000 loops, best of 3: 409 µs per loop
10000 loops, best of 3: 57.4 µs per loop
10000 loops, best of 3: 151 µs per loop
10000 loops, best of 3: 68.5 µs per loop
--------
10000 loops, best of 3: 51.6 µs per loop
10000 loops, best of 3: 126 µs per loop
10000 loops, best of 3: 57.5 µs per loop
10000 loops, best of 3: 151 µs per loop
10000 loops, best of 3: 68.5 µs per loop
--------
0.28.1
10000 loops, best of 3: 59.4 µs per loop
10000 loops, best of 3: 153 µs per loop
10000 loops, best of 3: 192 µs per loop
1000 loops, best of 3: 207 µs per loop
10000 loops, best of 3: 68.5 µs per loop
--------
10000 loops, best of 3: 102 µs per loop
1000 loops, best of 3: 406 µs per loop
10000 loops, best of 3: 192 µs per loop
1000 loops, best of 3: 207 µs per loop
10000 loops, best of 3: 68.4 µs per loop
--------
10000 loops, best of 3: 102 µs per loop
1000 loops, best of 3: 408 µs per loop
10000 loops, best of 3: 192 µs per loop
1000 loops, best of 3: 208 µs per loop
10000 loops, best of 3: 68.4 µs per loop
--------
10000 loops, best of 3: 102 µs per loop
10000 loops, best of 3: 126 µs per loop
10000 loops, best of 3: 192 µs per loop
1000 loops, best of 3: 206 µs per loop
10000 loops, best of 3: 68.6 µs per loop
--------
0.29.0

As you can see, in some unlucky cases the current implementation is 8x slower than nanmin_demo. About half of this speedup comes from using x.flat instead of np.nditer. I'm actually not sure if there are good reasons to use np.nditer instead of x.flat. If so, I'd be curious to learn them, because it would also affect a bunch of code that I wrote for numpy_groupies.
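To make the comparison concrete, here is an un-jitted sketch of the two iteration patterns being discussed (in practice each function would carry an @nb.njit decorator; the function names are illustrative). x.flat walks elements in logical C order, while np.nditer yields 0-d views that need .item():

```python
import numpy as np

# Two jittable iteration patterns over the same data (shown un-jitted).
def sum_flat(x):
    total = 0.0
    for v in x.flat:          # yields scalars, logical C order
        total += v
    return total

def sum_nditer(x):
    total = 0.0
    for view in np.nditer(x):  # yields 0-d views, possibly memory order
        total += view.item()
    return total

x = np.arange(6.0).reshape(2, 3)
assert sum_flat(x) == sum_nditer(x) == 15.0
```

Both produce the same result for order-insensitive reductions; the question in this thread is purely which one the compiler turns into a tighter loop.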

The other major speedup seems to come from some unlucky branching when the if condition is fed more than a simple nan-check or comparison. The jitted code seems to suffer much more from compound if-statements than ordinary C code would. Any ideas why that is?

If it is actually valid to use the flatiter for these cases, I'd go ahead and put some more optimizations together for a pull request.

gdementen commented 7 years ago

You might want to have a look at https://github.com/shoyer/numbagg

pitrou commented 7 years ago

Well, nanmin_impl() doesn't use the same algorithm as nanmin(), so you may be comparing apples to oranges here.

ml31415 commented 7 years ago

@gdementen Thanks for the note, I added it to the benchmark. Looks like numbagg doesn't use faster functions on flat arrays due to #1087.

One more strange thing: This benchmark was taken with 0.28.1. When I upgraded to 0.29.0, the timings got up to around 100µs for nanmin. The numbagg implementation also suffers a lot. I added the timings above. Any ideas on that?

@pitrou I just make sure to get a non-nan value beforehand, so that I can rely on the comparison evaluating to False whenever I hit a nan within the actual loop, and don't have to check for nan explicitly over and over again. In the end it does much the same thing, I guess.
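The two-pass trick described here can be sketched un-jitted as follows (in the benchmark above this is nanmin_demo, decorated with @nb.njit):

```python
import numpy as np

# Two-pass nanmin: seed with the first non-nan value, then exploit the fact
# that any comparison with nan is False, so the hot loop needs no isnan check.
def nanmin_two_pass(x):
    ret = np.nan
    for x_ in x.flat:          # pass 1: find a non-nan seed
        if not np.isnan(x_):
            ret = x_
            break
    if np.isnan(ret):          # all-nan (or empty) input
        return ret
    for x_ in x.flat:          # pass 2: `nan < ret` is always False
        if x_ < ret:
            ret = x_
    return ret

assert nanmin_two_pass(np.array([np.nan, 3.0, 1.0, np.nan])) == 1.0
assert np.isnan(nanmin_two_pass(np.array([np.nan, np.nan])))
```

The cost is that the elements before the seed are visited twice, which the comment in the original snippet already flags.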

ml31415 commented 7 years ago

I guess we should maybe split this up into the speed regression on one hand, and further improvements on top of that on the other.

pitrou commented 7 years ago

In the end it does quite the same I guess.

My point is that you cannot compare nditer() and flat performance if you don't use them in the exact same way. Ideally nditer() should always be as fast as flat (and faster for non-C arrays, since it is allowed to walk the array in non-logical order).

ml31415 commented 7 years ago

My main question is: is there any important reason not to simply iterate over the plain array? All the functions currently treat the input array as 1d anyway, so it wouldn't matter at all in which order the array is iterated. Using nditer imho just adds some overhead in that case. I extended the benchmarking a bit in this gist. Here are the current results:

0.28.1
                  numpy     numba     new      
---------------------------------------------
nanmin    nans    0.04066   0.05641   0.03758  
nanmin    nonans  0.04073   0.05690   0.03318  
---------------------------------------------
nanmax    nans    0.03982   0.06563   0.03342  
nanmax    nonans  0.03981   0.05731   0.03394  
---------------------------------------------
nansum    nans    0.05121   0.05705   0.03537  
nansum    nonans  0.04213   0.05773   0.03382  
---------------------------------------------
nanmean   nans    0.04621   0.05895   0.03313  
nanmean   nonans  0.05358   0.06331   0.06475  
---------------------------------------------
nanvar    nans    0.06668   0.09916   0.04327  
nanvar    nonans  0.07438   0.09794   0.05085  
---------------------------------------------
nanstd    nans    0.04514   0.05655   0.03473  
nanstd    nonans  0.04121   0.05822   0.03327  
---------------------------------------------
nanmedian nans    0.04447   0.05676  
nanmedian nonans  0.03985   0.05692  
---------------------------------------------
all       int     0.05230   0.05760   0.03325  
all       float   0.04987   0.06085   0.03349  
---------------------------------------------
any       int     0.05157   0.05700   0.03399  
any       float   0.04824   0.05727   0.03319 

As the table shows, notable speedups seem possible in nearly all cases.
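A minimal stand-in for the kind of harness behind such a table might look like the following (the gist itself is not reproduced here; the helper name and the choice of timeit are illustrative only):

```python
import timeit
import numpy as np

# Time each implementation on the same input; returns seconds per call.
def bench(impls, x, number=100):
    return {f.__name__: timeit.timeit(lambda f=f: f(x), number=number) / number
            for f in impls}

x = np.random.random(100_000)
x_nans = x.copy()
x_nans[x_nans > 0.5] = np.nan    # "nans" row: roughly half the values missing

times = bench([np.nanmin, np.nanmax, np.nansum], x_nans)
assert set(times) == {"nanmin", "nanmax", "nansum"}
assert all(t > 0 for t in times.values())
```

Note that %timeit in the original script is an IPython magic; timeit.timeit is the plain-Python equivalent.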

ml31415 commented 7 years ago

As these functions are real bread-and-butter number-crunching tools, I guess it really makes sense to have them as fast as possible for every user. I'd volunteer to do the legwork for the speedups, though it would be nice to have a definitive answer about the use of nditer first.

gdementen commented 7 years ago

I never even looked at the implementation of these functions in numba, but nditer might be used to reuse the same code for the case where the axis argument is provided. Just a thought...

pitrou commented 7 years ago

@ml31415, please try your benchmark with Fortran-ordered arrays.

ml31415 commented 7 years ago

@pitrou I understand that the flatiter has to traverse Fortran-ordered arrays against their memory layout in that case, and that this costs performance. But nditer also seems to bring some overhead, which I'd like to avoid for simple 1d cases. What I'd be looking for is a way to read the array as-is from memory, in whichever order it comes; a+b == b+a in the end. For 1d arrays I can simply iterate `for x in y`, without .flat, ravel() or nditer(), but for nd-arrays the trouble starts. One option would be to use different functions for 1d/nd, but I had hoped for better suggestions to force the array to be seen as 1d and iterated as such.
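For contiguous arrays, "read the array as-is from memory" can be had with a no-copy reshape; only genuinely strided views would need a copy. A hedged sketch (the helper name is hypothetical, and this is plain NumPy, not a numba typed-dispatch implementation):

```python
import numpy as np

# View any array as flat 1-D without copying when the layout allows it.
# Valid only for order-insensitive reductions (a+b == b+a).
def as_flat_1d(a):
    if a.flags.c_contiguous:
        return a.reshape(-1)
    if a.flags.f_contiguous:
        return a.T.reshape(-1)   # memory order, not logical order
    return a.ravel()             # strided view: may copy

a = np.asfortranarray(np.arange(6.0).reshape(2, 3))
flat = as_flat_1d(a)
assert flat.base is not None                          # a view, no copy
assert flat.tolist() == [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]  # memory order
assert np.isclose(flat.sum(), a.sum())
```

Inside an @overload this dispatch would happen at compile time on the array type, so the jitted loop itself only ever sees a 1d contiguous operand.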

pitrou commented 7 years ago

What I'd be looking for is a way to read the array as is from memory

Well, that's exactly what nditer() does :-) If there is a regression in nditer() performance it should be investigated (perhaps @stuartarchibald wants to take a look).

ml31415 commented 7 years ago

Hmm, for some loops nditer() actually comes quite close to plain iteration speed, though for some others it clearly falls behind. I suppose it may be compiler optimizations that sometimes kick in and sometimes don't. What I had imagined is something like this:

def get_flatiter(a):
    # Dispatch on the array *type* at compile time; note the consumer may
    # need .item() in the nditer case, since it yields 0-d views.
    if a.ndim > 1:
        @register_jitable
        def flatiter(arr):
            return np.nditer(arr)
    else:
        @register_jitable
        def flatiter(arr):
            return arr
    return flatiter

@overload(np.nansum)
def np_nansum(a):
    if not isinstance(a, types.Array):
        return
    if isinstance(a.dtype, types.Integer):
        retty = types.intp
    else:
        retty = a.dtype
    zero = retty(0)
    isnan = get_isnan(a.dtype)
    flatiter = get_flatiter(a)

    def nansum_impl(arr):
        c = zero
        for v in flatiter(arr):
            if not isnan(v):
                c += v
        return c

    return nansum_impl
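The dispatch idea above can be illustrated un-jitted in plain Python (names here are illustrative; in numba the choice would be made once at compile time inside the @overload, not at call time):

```python
import numpy as np

# Pick the iteration strategy once, based on dimensionality,
# instead of branching per element.
def make_nansum(ndim):
    if ndim > 1:
        def body(arr):
            c = 0.0
            for view in np.nditer(arr):   # nd case: 0-d views
                v = view.item()
                if not np.isnan(v):
                    c += v
            return c
    else:
        def body(arr):
            c = 0.0
            for v in arr:                 # 1d case: plain iteration
                if not np.isnan(v):
                    c += v
            return c
    return body

a = np.array([1.0, np.nan, 2.0])
assert make_nansum(a.ndim)(a) == 3.0
b = np.array([[1.0, np.nan], [2.0, 4.0]])
assert make_nansum(b.ndim)(b) == 7.0
```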

stuartarchibald commented 7 years ago

I'll take a look when I next have some spare cycles. Something appears to have caused a performance regression in https://github.com/numba/numba/issues/2196, and I guess potentially here too. Looking at the optimised asm dump from #2196: 1) there are a lot more instructions present, and 2) the selected instructions seem generally wider, even though the LLVM backend wasn't updated as far as I'm aware. As discussed with @pitrou offline, I'll do a bisect when I get a chance.

stuartarchibald commented 7 years ago

Finally got around to bisecting this.

d6d14642f9642ef4337e5034ff0b9ffe29fb53ba is the first bad commit

which is d6d1464 from PR #2050