scikit-hep / awkward

Manipulate JSON-like data with NumPy-like idioms.
https://awkward-array.org
BSD 3-Clause "New" or "Revised" License

Scalar type promotion not working #3128

Open nsmith- opened 4 months ago

nsmith- commented 4 months ago

Version of Awkward Array

2.6.4

Description and code to reproduce

In the following code, comparing against an IntEnum member works for a NumPy array but fails for an Awkward Array:

import numpy as np
import awkward as ak
from enum import IntEnum

class ParticleOrigin(IntEnum):
    NonDefined: int = 0
    SingleElec: int = 1
    SingleMuon: int = 2

# works as expected
print(np.arange(10) == ParticleOrigin.SingleElec)
# errors
print(ak.Array(np.arange(10)) == ParticleOrigin.SingleElec)
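IntEnum members are instances of int, which is why NumPy's ordinary scalar coercion accepts them. The following NumPy-only check illustrates this. The enum is redefined without the ": int" annotations so the snippet runs on its own, and the variable name member is just for illustration.

```python
import numpy as np
from enum import IntEnum


class ParticleOrigin(IntEnum):
    NonDefined = 0
    SingleElec = 1
    SingleMuon = 2


member = ParticleOrigin.SingleElec

# IntEnum members subclass int, but their exact type is the enum class
print(isinstance(member, int))  # True
print(type(member) is int)      # False

# NumPy coerces the member to an ordinary integer array
print(np.asarray(member).dtype.kind)  # i

# The comparison from the report is True only at index 1
print(np.flatnonzero(np.arange(10) == member))  # [1]
```

The distinction between "is an int" and "is exactly int" turns out to matter for the failure below.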

NumPy recognizes that the IntEnum member can be promoted to int64, but Awkward fails with the error:

Traceback (most recent call last):
  File "/Users/ncsmith/src/tmp.py", line 16, in <module>
    print(ak.Array(np.arange(10)) == ParticleOrigin.SingleElec)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_operators.py", line 53, in func
    return ufunc(self, other)
           ^^^^^^^^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/highlevel.py", line 1516, in __array_ufunc__
    return ak._connect.numpy.array_ufunc(ufunc, method, inputs, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_connect/numpy.py", line 466, in array_ufunc
    out = ak._broadcasting.broadcast_and_apply(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_broadcasting.py", line 968, in broadcast_and_apply
    out = apply_step(
          ^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_broadcasting.py", line 946, in apply_step
    return continuation()
           ^^^^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_broadcasting.py", line 915, in continuation
    return broadcast_any_list()
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_broadcasting.py", line 622, in broadcast_any_list
    outcontent = apply_step(
                 ^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_broadcasting.py", line 928, in apply_step
    result = action(
             ^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_connect/numpy.py", line 432, in action
    result = backend.nplike.apply_ufunc(ufunc, method, input_args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_nplikes/array_module.py", line 208, in apply_ufunc
    return self._apply_ufunc_nep_50(ufunc, method, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_nplikes/array_module.py", line 235, in _apply_ufunc_nep_50
    resolved_dtypes = ufunc.resolve_dtypes(arg_dtypes)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Provided dtype must be a valid NumPy dtype, int, float, complex, or None.

This error occurred while calling

    numpy.equal.__call__(
        <Array [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] type='10 * int64'>
        <ParticleOrigin.SingleElec: 1>
    )

cc @kratsg

agoose77 commented 4 months ago

This should be supported, but currently fails. Even the Array API (which we don't promise to conform to, but take as inspiration for our promotion rules) supports this: https://data-apis.org/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars

I will probably work on this over the weekend.

kratsg commented 4 months ago

Note that

print(ak.Array(np.arange(10)) == ParticleOrigin.SingleElec.value)

still works (i.e., passing the plain int from .value, as one would with a regular Enum, seems to be fine).
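A likely reason the .value workaround succeeds is that it yields an exact builtin int rather than the enum subclass. The pure-Python check below shows this, and also shows that int(member) does the same, so int(ParticleOrigin.SingleElec) would be expected to behave like the .value form.

```python
from enum import IntEnum


class ParticleOrigin(IntEnum):
    NonDefined = 0
    SingleElec = 1
    SingleMuon = 2


member = ParticleOrigin.SingleElec

# Both spellings produce an exact builtin int, not the IntEnum subclass
print(type(member.value))  # <class 'int'>
print(type(int(member)))   # <class 'int'>

# Both still compare equal to the member itself
print(member.value == member, int(member) == member)  # True True
```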