data-apis / array-api-strict

Strict implementation of the Python array API (previously numpy.array_api)
http://data-apis.org/array-api-strict/

Allow comparing any numeric types in boolean functions #50

Open asmeurer opened 1 month ago

asmeurer commented 1 month ago

Comparison functions like equal, greater, and so on (and their operator equivalents) don't allow comparing dtypes that cannot be promoted together. This is particularly annoying because it makes it impossible to compare uint64 with int64 at all, since those two dtypes cannot be promoted.

>>> import array_api_strict as xp
>>> xp.asarray(0, dtype=xp.int64) < xp.asarray(1, dtype=xp.uint64)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/aaronmeurer/Documents/array-api-strict/array_api_strict/_array_object.py", line 717, in __lt__
    other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmeurer/Documents/array-api-strict/array_api_strict/_array_object.py", line 179, in _check_allowed_dtypes
    res_dtype = _result_type(self.dtype, other.dtype)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmeurer/Documents/array-api-strict/array_api_strict/_dtypes.py", line 217, in _result_type
    raise TypeError(f"{type1} and {type2} cannot be type promoted together")
TypeError: array_api_strict.int64 and array_api_strict.uint64 cannot be type promoted together
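The reason int64 and uint64 cannot promote is their ranges. The largest integer dtypes in the standard are 64 bits wide, and holding every int64 value together with every uint64 value would need a 65-bit signed integer. The short pure-Python sketch below computes that width. The helper min_signed_bits is a hypothetical name used only for this illustration.

```python
# Value ranges of the two dtypes, derived from their bit widths.
INT64_RANGE = (-(2**63), 2**63 - 1)
UINT64_RANGE = (0, 2**64 - 1)


def min_signed_bits(lo, hi):
    """Hypothetical helper: smallest two's-complement width holding [lo, hi]."""
    bits = 1
    while not (-(2 ** (bits - 1)) <= lo and hi <= 2 ** (bits - 1) - 1):
        bits += 1
    return bits


# The union of both ranges is what a common promoted dtype would need to hold.
lo = min(INT64_RANGE[0], UINT64_RANGE[0])
hi = max(INT64_RANGE[1], UINT64_RANGE[1])
needed = min_signed_bits(lo, hi)
print(needed)  # 65
```

Because no 65-bit (or wider) integer dtype exists in the standard, a promotion-based comparison has nowhere to go.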

However, the standard doesn't actually say anywhere in the specification of greater or __gt__ that the input types must be promotable:

https://data-apis.org/array-api/latest/API_specification/generated/array_api.greater.html#greater https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__gt__.html

only that they should have a real numeric data type. So in principle, these operators should work even when comparing floats with integers.
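Comparing without a common dtype is possible for this pair. A negative int64 is less than every uint64, and a nonnegative int64 converts to uint64 exactly, so the comparison can happen there. The sketch below assumes NumPy as the backend. less_int64_uint64 is a hypothetical helper, not a function in array-api-strict or NumPy.

```python
import numpy as np


def less_int64_uint64(x, y):
    """Hypothetical helper: elementwise x < y for int64 x and uint64 y.

    No common dtype is used. Negative x is handled separately, and the
    remaining (nonnegative) values are compared exactly as uint64.
    """
    x = np.asarray(x, dtype=np.int64)
    y = np.asarray(y, dtype=np.uint64)
    negative = x < 0  # every negative int64 is below every uint64
    # Mask negatives to 0 before casting so the cast never wraps around;
    # for x >= 0 the int64 -> uint64 cast is exact.
    x_as_u = np.where(negative, 0, x).astype(np.uint64)
    return negative | (x_as_u < y)


print(less_int64_uint64([-5, 0, 7], [0, 0, 2**63]))
```

The other ordering comparisons can be derived the same way, for example x > y as less_int64_uint64 applied with the roles of the operands reversed by an analogous uint64-vs-int64 helper.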

And equal allows any data type: https://data-apis.org/array-api/latest/API_specification/generated/array_api.equal.html#equal, https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__eq__.html

It might be good to get some clarification in the standard about this, for instance, on how == should behave when mixing certain dtype combinations.
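Equality across int64 and uint64 has a specific pitfall. Naively casting int64 to uint64 wraps negative values, so -1 would become 2**64 - 1 and compare equal to the largest uint64. The sketch below assumes NumPy and uses a hypothetical helper, equal_int64_uint64, that is not part of any library. It masks out negative values, which can never equal a uint64.

```python
import numpy as np


def equal_int64_uint64(x, y):
    """Hypothetical helper: elementwise x == y for int64 x and uint64 y."""
    x = np.asarray(x, dtype=np.int64)
    y = np.asarray(y, dtype=np.uint64)
    nonneg = x >= 0  # a negative int64 never equals any uint64
    # Replace negatives with 0 before casting, so -1 cannot wrap to 2**64 - 1.
    x_as_u = np.where(nonneg, x, 0).astype(np.uint64)
    return nonneg & (x_as_u == y)


print(equal_int64_uint64([-1, 0, 9], [2**64 - 1, 0, 8]))
```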

asmeurer commented 1 month ago

It would be good to get some clarification in the standard for equality comparisons: https://github.com/data-apis/array-api/issues/819.

That said, we can probably just fall back to what NumPy does for now. The only potential problem is the pre-2.0 promotion behavior, which is another argument for making NumPy 2.0 a hard dependency (#21). I also need to double-check that NumPy 2.0 isn't internally promoting uint64 and int64 to float64, although if it is, I doubt I can reasonably work around it.
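Promoting to float64 would matter because float64 has only a 53-bit significand and cannot represent every 64-bit integer. The pure-Python sketch below uses Python's float, which is an IEEE 754 double like float64, as a stand-in for that promotion. It does not claim anything about what NumPy does internally. It shows that the largest int64 and the value 2**63 (representable only as uint64) become indistinguishable once rounded.

```python
# Largest int64 value, and the smallest uint64 value that int64 cannot hold.
int64_max = 2**63 - 1
uint64_value = 2**63

# Exact comparison: Python ints have arbitrary precision.
exact_less = int64_max < uint64_value

# Same comparison after rounding both sides to a double (53-bit significand):
# 2**63 - 1 rounds up to 2**63, so the ordering information is lost.
float_less = float(int64_max) < float(uint64_value)
float_equal = float(int64_max) == float(uint64_value)

print(exact_less, float_less, float_equal)  # True False True
```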