dask-contrib / dask-awkward

Native Dask collection for awkward arrays, and the library to use it.
https://dask-awkward.readthedocs.io
BSD 3-Clause "New" or "Revised" License

Cannot subscript by `argsort` output #151

Closed masonproffitt closed 1 year ago

masonproffitt commented 1 year ago

For example:

>>> import awkward as ak, dask_awkward as dak
>>> a = ak.Array([[], [3, 1, 2]])
>>> da = dak.from_awkward(a, 1)
>>> dak.argsort(da).compute()
<Array [[], [1, 2, 0]] type='2 * var * int64'>
>>> da[dak.argsort(da)]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/user/iris-hep/src/dask-awkward/src/dask_awkward/lib/core.py", line 896, in __getitem__
    return self._getitem_single(where)
  File "/home/user/iris-hep/src/dask-awkward/src/dask_awkward/lib/core.py", line 867, in _getitem_single
    raise DaskAwkwardNotImplemented(f"__getitem__ doesn't support where={where}.")
dask_awkward.utils.DaskAwkwardNotImplemented: __getitem__ doesn't support where=dask.awkward<argsort, npartitions=1>.

If you would like this unsupported call to be supported by
dask-awkward please open an issue at:
https://github.com/dask-contrib/dask-awkward.

I would expect the last line to be equivalent to dak.sort(da). This is my main use case for argsort: sorting one field of an Array using a different field as the ordering key.
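To make the expected semantics concrete, here is a minimal plain-Python sketch that models jagged arrays as nested lists. It is not dask-awkward or Awkward code; the helper names (`argsort`, `take`, `jagged_argsort`, `jagged_take`) and the `pt`/`name` fields are hypothetical and exist only for illustration. It shows that indexing by the per-sublist argsort reproduces a per-sublist sort, and that the same index can reorder a second field.

```python
def argsort(values):
    """Indices that would sort `values` in ascending order (stable)."""
    return sorted(range(len(values)), key=values.__getitem__)


def take(values, index):
    """Reorder `values` according to a list of integer positions."""
    return [values[i] for i in index]


def jagged_argsort(array):
    """Apply argsort to each sublist, like argsort along the last axis."""
    return [argsort(sub) for sub in array]


def jagged_take(array, index):
    """Index each sublist of `array` by the matching sublist of `index`."""
    return [take(sub, idx) for sub, idx in zip(array, index)]


# Same data as in the report above.
a = [[], [3, 1, 2]]
order = jagged_argsort(a)
print(order)                  # [[], [1, 2, 0]]
print(jagged_take(a, order))  # [[], [1, 2, 3]]

# Hypothetical fields: reorder `name` using `pt` as the sort key.
pt = [[], [30.0, 10.0, 20.0]]
name = [[], ["c", "a", "b"]]
print(jagged_take(name, jagged_argsort(pt)))  # [[], ['a', 'b', 'c']]
```

In this model, `jagged_take(a, jagged_argsort(a))` plays the role of `da[dak.argsort(da)]`, and the last line is the "sort one field by another" use case.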

jpivarski commented 1 year ago

The default axis for ak.argsort is -1. Implementing axis != 0 (this case) is more straightforward than axis == 0.

jpivarski commented 1 year ago

Needs another elif case for

https://github.com/dask-contrib/dask-awkward/blob/baa91e946649dbed7d9a539e41b38fc1368cefd1/src/dask_awkward/lib/core.py#L841-L872

jpivarski commented 1 year ago

On determining whether a negative axis is equivalent to axis=0:

Given this array,

>>> array = ak.Array([[{"x": 1.1, "y": [1]}], [], [{"x": 2.2, "y": [1, 2]}]])
>>> array.type.show()
3 * var * {
    x: float64,
    y: var * int64
}

For x, axis=-2 is equivalent to axis=0:

>>> ak.sum(array.x, axis=0), ak.sum(array.x, axis=-2)
(<Array [3.3] type='1 * float64'>, <Array [3.3] type='1 * float64'>)

For y, axis=-3 is equivalent to axis=0:

>>> ak.sum(array.y, axis=0), ak.sum(array.y, axis=-3)
(<Array [[2, 2]] type='1 * var * int64'>, <Array [[2, 2]] type='1 * var * int64'>)

We could derive that from the minmax_depth:

>>> array.layout.minmax_depth
(2, 3)

which doesn't say which has depth 2 and which has depth 3, just that something does. If a given axis is negative and

This last case is one that you can always offload to Awkward. If

The only other case, -axis > min(array.layout.minmax_depth) would even be an Awkward axis out of range error (np.AxisError, but maybe ValueError?).

>>> ak.sum(array.y, axis=-4)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/jpivarski/irishep/awkward/src/awkward/operations/ak_sum.py", line 218, in sum
    return _impl(array, axis, keepdims, mask_identity, highlevel, behavior)
  File "/home/jpivarski/irishep/awkward/src/awkward/operations/ak_sum.py", line 292, in _impl
    out = ak._do.reduce(
  File "/home/jpivarski/irishep/awkward/src/awkward/_do.py", line 383, in reduce
    raise ak._errors.wrap_error(
ValueError: while calling

    ak.sum(
        array = <Array [[[1]], [], [[1, 2]]] type='3 * var * var * int64'>
        axis = -4
        keepdims = False
        mask_identity = False
        highlevel = True
        behavior = None
    )

Error details: axis=-4 exceeds the depth of the nested list structure (which is 3)

I expected that to be an np.AxisError, not a ValueError...
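As a rough illustration of the reasoning above, the following plain-Python sketch classifies a negative axis using only a (min, max) depth pair such as the one returned by `minmax_depth`. The function name `classify_negative_axis` and its return strings are hypothetical, and the branch conditions are one reading of the examples in this comment (x has depth 2, y has depth 3), not a reconstruction of any actual dask-awkward or Awkward logic.

```python
def classify_negative_axis(axis, minmax_depth):
    """Hypothetical helper: say what a (min, max) depth pair alone
    tells us about a negative `axis`."""
    if axis >= 0:
        raise ValueError("this sketch only handles negative axis values")

    min_depth, max_depth = minmax_depth

    if -axis > max_depth:
        # Deeper than every branch, so no branch can resolve it.
        return "out of range"

    if min_depth == max_depth:
        # Uniform depth: axis resolves to depth + axis in every branch.
        return "axis=0" if -axis == min_depth else "not axis=0"

    if -axis < min_depth:
        # Every branch is deeper than -axis, so it resolves to axis > 0.
        return "not axis=0"

    # min_depth <= -axis <= max_depth with mixed depths: a branch whose
    # depth equals -axis sees axis=0, a shallower branch would be out of
    # range, and a deeper branch sees axis > 0.
    return "depends on branch"


print(classify_negative_axis(-2, (2, 2)))  # axis=0 (like array.x)
print(classify_negative_axis(-3, (3, 3)))  # axis=0 (like array.y)
print(classify_negative_axis(-4, (3, 3)))  # out of range
print(classify_negative_axis(-2, (2, 3)))  # depends on branch
print(classify_negative_axis(-1, (2, 3)))  # not axis=0
```

The "depends on branch" result corresponds to the situation where the depth pair says that some field has a given depth without saying which one.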

douglasdavis commented 1 year ago

Thanks for that explanation, Jim! I'll open another issue to track handling of a negative axis that is equivalent to axis=0.