scikit-hep / awkward

Manipulate JSON-like data with NumPy-like idioms.
https://awkward-array.org
BSD 3-Clause "New" or "Revised" License

NumPy operations over arrays of record fields in the Numba context are not working #1282

Closed: tamasgal closed this issue 2 years ago

tamasgal commented 2 years ago

Version of Awkward Array

1.7.0

Description and code to reproduce

I need to run complex algorithms on chunks of Awkward Arrays (containing nested arrays of records) and have boiled down one of my main issues to a minimal working example (MWE). I have the feeling that I am blindly doing something wrong, because this is fairly basic, but I was not able to find answers in the docs. I tend to believe that this is a bug, or at least a usability issue, so I decided to open an issue instead of a discussion.

The Numba docs at https://awkward-array.org/how-to-use-in-numba-features.html are still a work in progress (#1064), but the API documentation mentions that records are supported in the Numba context. Calling ak.numba.register() (described in https://awkward-array.readthedocs.io/en/latest/ak.numba.register.html?highlight=numba%20record) did not help.

Here is the simple example:

import awkward as ak # v1.7.0
import numpy as np  # v1.21.5
import numba as nb  # v0.55.1

arr = ak.Array([{"a": 1}, {"a": 2}, {"a": 3}])

@nb.njit
def foo(arr):
    return np.sum(arr.a)

foo(arr)

which gives

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Input In [100], in <module>
      7 @nb.njit
      8 def foo(arr):
      9     return np.sum(arr.a)
---> 11 foo(arr)

File ~/Dev/km3io/venv/lib/python3.9/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
    464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
    465                f"by the following argument(s):\n{args_str}\n")
    466         e.patch_message(msg)
--> 468     error_rewrite(e, 'typing')
    469 except errors.UnsupportedError as e:
    470     # Something unsupported is present in the user code, add help info
    471     error_rewrite(e, 'unsupported_error')

File ~/Dev/km3io/venv/lib/python3.9/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
    407     raise e
    408 else:
--> 409     raise e.with_traceback(None)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function sum at 0x108261ca0>) found for signature:

 >>> sum(ak.ArrayView(ak.NumpyArrayType(array(int64, 1d, A), none, {}), None, ()))

There are 2 candidate implementations:
      - Of which 2 did not match due to:
      Overload in function 'Numpy_method_redirection.generic': File: numba/core/typing/npydecl.py: Line 379.
        With argument(s): '(ak.ArrayView(ak.NumpyArrayType(array(int64, 1d, A), none, {}), None, ()))':
       Rejected as the implementation raised a specific error:
         TypeError: array does not have a field with key 'sum'

       (https://github.com/scikit-hep/awkward-1.0/blob/1.7.0/src/awkward/_connect/_numba/layout.py#L341)
  raised from /Users/tamasgal/Dev/km3io/venv/lib/python3.9/site-packages/awkward/_connect/_numba/layout.py:339

During: resolving callee type: Function(<function sum at 0x108261ca0>)
During: typing of call at /var/folders/84/mcvklq757tq1nfrkbxvvbq8m0000gn/T/ipykernel_6479/2151303190.py (9)

File "../../../../../var/folders/84/mcvklq757tq1nfrkbxvvbq8m0000gn/T/ipykernel_6479/2151303190.py", line 9:
<source missing, REPL/exec in use?>

It also fails for sum(), and I even tried arr.to_numpy() / ak.to_numpy(arr) inside the JIT-compiled function, but those functions are not available in the Numba context.

Of course, I know that one can do very efficient high-level operations outside of Numba, but I really am forced to use nested loops and Awkward ArrayBuilders. I already use those successfully with Numba and they work fine, but so far no records were involved, so I had not encountered this issue yet.

agoose77 commented 2 years ago

At the moment, we don't support NumPy operations on Awkward Arrays inside Numba-jitted functions. @jpivarski might correct me here, but we do support converting 1D Awkward Arrays to NumPy arrays inside Numba functions, e.g.

@nb.njit
def sum_1d(x):
    y = np.asarray(x)
    return np.sum(y)

sum_1d(ak.Array([1, 2, 3]))

I took a quick look, and I don't think we support asarray for regular N-dimensional arrays; we only implement the typer for the 1D case.

The solution to your problem really depends on the algorithm that you're trying to implement. In general, it's better to use bare loops in Numba because, IIRC, NumPy operations (even in Numba) create temporary arrays that Numba is unable to elide.
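A minimal sketch of the temporaries point, using plain NumPy arrays rather than Awkward Arrays: `np.sum(x * x)` first materializes `x * x` as a new array and then reduces it, while the explicit loop accumulates into a scalar and allocates no intermediate array. Both function names are hypothetical. Numba is treated as optional here (an assumption for portability): if it is not installed, `njit` falls back to a no-op decorator and the functions run as ordinary Python.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python if Numba is missing
    def njit(func):
        return func


@njit
def sum_of_squares_numpy(x):
    # x * x allocates a temporary array before np.sum reduces it
    return np.sum(x * x)


@njit
def sum_of_squares_loop(x):
    # Accumulate into a scalar; no intermediate array is created
    total = 0.0
    for value in x:
        total += value * value
    return total


x = np.arange(5, dtype=np.float64)
print(sum_of_squares_numpy(x), sum_of_squares_loop(x))
```

Both functions return the same value; the sketch only shows where an intermediate array appears and makes no claim about how much that costs in any particular case.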

Note that the ArrayBuilder is a useful helper, but it will slow your code down compared with the most performant approaches. This is only something to worry about if your code is too slow (however that is defined)! In general, the way to use Numba with Awkward is to prepare the structure outside the Numba function and run the tight loops inside it. This includes handling structure (e.g. ak.unflatten) outside of Numba. Ideally, you can pass the result array(s) in directly and restore the structure afterwards, e.g.

@nb.njit
def _sum_last_2d(array, result):
    for i, inner in enumerate(array):
        x = 0.0
        for y in inner:
            x += y
        result[i] = x

my_jagged_array = ak.Array([[[1, 2, 3, 4]], []])

# Flatten array to make a 2D array of arrays (with varying lengths)
_sum_2d_input = ak.flatten(my_jagged_array, axis=1)
_sum_2d_output = np.empty(len(_sum_2d_input))

# Perform sum over last axis
_sum_last_2d(_sum_2d_input, _sum_2d_output)

# Unflatten result to restore jagged structure
result = ak.unflatten(_sum_2d_output, ak.num(my_jagged_array, axis=1))

Of course, you could avoid flattening the array first by writing three nested loops in your Numba function, but then you have to start handling list offsets inside the Numba function, which is error-prone. It's much more convenient to simplify the structure before calling Numba, and then rebuild it after the Numba function has run.
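For comparison, the offset-handling alternative can be sketched with hand-built NumPy buffers for the hypothetical jagged array `[[[1, 2], [3]], [], [[4, 5, 6]]]`. The assumption is that the array is stored as two levels of list offsets over a flat `content` buffer, similar to what an Awkward ListOffsetArray layout holds; here the buffers are written out by hand rather than taken from a real layout, and `sum_innermost` is a hypothetical name. Numba is treated as optional: if it is not installed, `njit` degrades to a no-op and the loops run as plain Python.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python if Numba is missing
    def njit(func):
        return func


@njit
def sum_innermost(outer_offsets, inner_offsets, content, result):
    # Three nested loops: outer lists -> middle lists -> numbers.
    # outer_offsets index into the middle lists; inner_offsets index into content.
    for i in range(len(outer_offsets) - 1):
        for j in range(outer_offsets[i], outer_offsets[i + 1]):
            total = 0.0
            for k in range(inner_offsets[j], inner_offsets[j + 1]):
                total += content[k]
            result[j] = total


# Hypothetical buffers for [[[1, 2], [3]], [], [[4, 5, 6]]]
outer_offsets = np.array([0, 2, 2, 3], dtype=np.int64)
inner_offsets = np.array([0, 2, 3, 6], dtype=np.int64)
content = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

# One output slot per middle list
result = np.empty(len(inner_offsets) - 1)
sum_innermost(outer_offsets, inner_offsets, content, result)

# Rebuild the outer structure from the same outer offsets
nested = [result[outer_offsets[i]:outer_offsets[i + 1]].tolist()
          for i in range(len(outer_offsets) - 1)]
print(nested)  # [[3.0, 3.0], [], [15.0]]
```

The outer loop is not strictly needed for this particular reduction, since the middle lists are stored contiguously and `j` could simply run over all of them; it is kept to mirror the three-level structure. Any off-by-one in these offset ranges silently produces wrong sums rather than an error.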

For posterity, another mechanism is the internal ak._util.recursively_apply helper, which runs your array logic at a particular array depth. This is easier than using ak.flatten and ak.unflatten if you want to support arrays of varying dimensionality, e.g.

@nb.njit
def _sum_last_2d(array, result):
    for i, inner in enumerate(array):
        x = 0.0
        for y in inner:
            x += y
        result[i] = x

my_jagged_array = ak.Array([[[1, 2, 3, 4]], []])

def sum_last(array):
    layout = ak.to_layout(array)

    # The "depth" of an Array Content starts at 1 (`axis=0`) and increases.
    # We want to act at `axis=-2`, which for this 3D array is `axis=1`,
    # i.e. `depth=2`. For a non-negative axis, the depth is always `axis + 1`.
    target_depth = layout.axis_wrap_if_negative(-2) + 1
    def getfunction(layout, depth):
        # If we're at the right depth
        if depth == target_depth:
            # Create our 1D output
            output = np.empty(len(layout))

            # Perform a sum over the final dimension of the current 2D (jagged) array
            _sum_last_2d(
                ak.Array(layout),
                output
            )

            # Wrap the NumPy array in an Awkward Content type so that
            # recursively_apply knows how to handle it.
            new_layout = ak.layout.NumpyArray(output)

            # Use a special return type (callable) to indicate we changed the layout at this depth
            return lambda: new_layout

    new_layout = ak._util.recursively_apply(layout, getfunction)
    return ak._util.wrap(new_layout, ak._util.behaviorof(array))

result = sum_last(my_jagged_array)

This is not easy, and is rarely something you'll need. I use this approach most often when writing functions that other users call, as it gets rid of the flatten/unflatten logic.

tamasgal commented 2 years ago

Thanks for the quick answer!

I just realised I had copied the wrong MWE; I have corrected it now.

Indeed, np.asarray() is a good workaround. It works after turning the (corrected) MWE:

@nb.njit
def foo(arr):
    return np.sum(arr.a)

into

@nb.njit
def foo(arr):
    y = np.asarray(arr.a)
    return np.sum(y)

I feel a bit dumb that I have not tried that ;)

Anyway, yes, I am aware of the suggested workflow of doing as much as possible outside of Numba (I already do this), but as always, the real use case is difficult to describe, and it felt more natural to access the record fields directly instead of writing low-level loops or doing array conversions inside the Numba context. I cannot really do it outside because of the huge number of records and record fields I need to iterate over. Our data is really not in good shape ;D

Yes, I had already seen https://github.com/scikit-hep/awkward-1.0/blob/4c5fa98a585ec84b149ff9fbb8d9ff2e713d21f1/src/awkward/_connect/_numba/arrayview.py#L1183 but thought I was missing some other magic behind the scenes ;)

Alright, I'll have a look and figure out how to restructure it. I'll close this now and come back if I hit a wall.

jpivarski commented 2 years ago

Yeah, getting NumPy functions to just recognize Awkward Arrays as arrays in Numba is tricky. There's an open issue about that: #509. One part of that might be to register Awkward Arrays as "ArrayLike" objects in Numba, but not all of Numba's NumPy overrides check for that abstract type: some check directly for a concrete np.ndarray. I brought that up in the Numba weekly meeting, and I think there's an issue about it. I think the problem is more general than extension arrays that want to emulate NumPy arrays: I think it also doesn't work for arbitrary sequences. NumPy's np.sum happily takes a list, but I don't think it does in a Numba-compiled function.

tamasgal commented 2 years ago

Ah, I totally missed your answer. Yes, I kind of see where the problems are now. It felt like this should naturally be possible, since the types are all well defined, but of course the devil is in the details. Anyway, I successfully converted that part with some dead simple np.asarray(...) calls and a bit of restructuring, and now the data processing runs super fast on deeply nested jagged record arrays. I have a C++ counterpart to compare against, and the performance is the same.

Thanks again, this library is so incredibly useful and necessary :)