mdbartos / pysheds

:earth_americas: Simple and fast watershed delineation in python.
GNU General Public License v3.0

Allow Numba to perform polymorphic dispatching #230

Open groutr opened 1 year ago

groutr commented 1 year ago

Strictly typing each function makes numba unable to compile specialized functions for smaller types. This negatively impacts both performance and memory usage of the numba code.

An example of this is _par_get_candidates_numba, called in the resolve_flats function. I have a fairly large test DEM of shape (11431, 11292). Reading from the raster produces an array of float32 values (the native dtype of the DEM). In the resolve_flats function, we construct an array called insides, an array of indexes that naturally fit in the int32 datatype. Because _par_get_candidates_numba is declared with an explicit signature, numba only accepts float64 and int64 arrays, so both inputs have to be cast up first, doubling the memory needed for each array.
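To put the doubling in concrete terms, the arithmetic below computes the size of one dense array covering every cell of the (11431, 11292) test DEM. This is a back-of-the-envelope sketch: array_nbytes is a hypothetical helper, and the real inputs to _par_get_candidates_numba may be smaller than the full grid, but the 2x ratio between the 32-bit and 64-bit dtypes holds for any array length.

```python
import numpy as np


def array_nbytes(shape, dtype):
    """Bytes needed to hold a dense array of the given shape and dtype."""
    n_cells = 1
    for dim in shape:
        n_cells *= dim
    return n_cells * np.dtype(dtype).itemsize


dem_shape = (11431, 11292)  # shape of the test DEM from this issue

# Size of one full-grid array for each candidate dtype.
for dtype in (np.float32, np.float64, np.int32, np.int64):
    mb = array_nbytes(dem_shape, dtype) / 1e6
    print(f"{np.dtype(dtype).name:>8}: {mb:,.0f} MB")
```

For this grid it reports about 516 MB per 32-bit array and about 1,033 MB per 64-bit array.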

Removing the explicit type signatures allows numba to be smarter about how the function is compiled. It can specialize a (float32, int32) version that runs in less than half the time of the (float64, int64) version and uses half the memory. This makes it possible to process larger DEMs more quickly.

```
# Requires allocating float64/int64 copies of the inputs, each 2x the size of the original array.
%timeit _par_get_candidates_numba_fixed(dem64, inside64)
1.09 s ± 65.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# Less memory and faster!
%timeit _par_get_candidates_numba_flexible(dem, inside)
442 ms ± 7.59 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

You can read more about numba's dispatching here: https://numba.readthedocs.io/en/stable/developer/dispatching.html

mdbartos commented 1 year ago

Thanks for pointing this out. I initially started with numba's polymorphic dispatching but moved to strict types later on. There are advantages offered by strict typing like ease of debugging. For some users (like myself), the JIT compilation overhead also starts to add up over multiple Python sessions and can be somewhat frustrating when experimenting in an interactive context.

For performance, my inclination would be to export multiple versions of each numba function, as in the standalone example from the numba AOT docs, and then infer types from the source data: https://numba.readthedocs.io/en/stable/user/pycc.html#standalone-example

groutr commented 1 year ago

@mdbartos I'm not sure if you noticed the note in the documentation about the pending deprecation of AOT compilation. I understand the debugging benefit; however, the inputs to most of these functions are already strictly typed, homogeneous arrays. The main request here is to allow using the smaller native dtypes of the source data. The compiled functions can still be cached, so the point about compilation overhead adding up doesn't really make sense to me.

mdbartos commented 1 year ago

Hi @groutr, in my experience numba still recompiles the function each time unless types are specified, even when cache=True.

Try running the following code sample, then resetting the kernel and trying again. The untyped one incurs compilation overhead on subsequent runs while the typed one does not:

```python
from numba import njit
import numpy as np
from numba.types import float64

@njit(float64(float64[:]), cache=True)
def norm_squared_typed(vec):
    n = len(vec)
    result = 0.
    for i in range(n):
        result += vec[i]**2
    return result

@njit(cache=True)
def norm_squared_untyped(vec):
    n = len(vec)
    result = 0.
    for i in range(n):
        result += vec[i]**2
    return result

vec = np.arange(10, dtype=np.float64)

%time norm_squared_typed(vec)
%time norm_squared_untyped(vec)
```

```
CPU times: user 150 µs, sys: 478 µs, total: 628 µs
Wall time: 80.3 µs
CPU times: user 9.54 ms, sys: 45.7 ms, total: 55.2 ms
Wall time: 6.04 ms
```

groutr commented 1 year ago

@mdbartos I think there is a little bit of misunderstanding of what numba caching is actually doing. Your benchmark isn't measuring what you think it's measuring.

After playing around with your example, here is what I think is happening. When you define norm_squared_typed, the types are completely specified, so numba can compile and cache the function at definition time. By the time you run your timing statement, it has already been compiled. When you reset your kernel, the function is still recompiled when you define it; the compilation is simply happening somewhere other than you expect.

On the other hand, norm_squared_untyped has to use lazy compilation, because there is not enough type information to compile it at definition time. numba uses the argument types of the first call to specialize the function, then compiles and caches it. All subsequent calls with the same signature use the cached version. Calling the function with different types compiles and caches a new version for those types as well.

When you reset the kernel, both functions get recompiled. It's just that one function gets compiled much earlier, when you define it, which makes it appear faster than the lazily compiled function on the first call. In reality, the compilation takes about the same time in both cases. You can see this happening if you turn on numba's cache debugging (the NUMBA_DEBUG_CACHE environment variable).

mdbartos commented 1 year ago

Thanks @groutr for the explanation. I re-tested with different breakpoints for timing and that appears to be correct.

I would be open to removing strict typing, but it would have to be done carefully. Maybe as part of a longer term release.