data-apis / array-api-compat

Compatibility layer for common array libraries to support the Array API
https://data-apis.org/array-api-compat/
MIT License

`sign` complex case implementations #183

Open mdhaber opened 1 month ago

mdhaber commented 1 month ago

Since the 2022.12 standard, the required implementation of sign has been:

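$$
\operatorname{sign}(x_i) = \begin{cases} 0 & \textrm{if } x_i = 0 \\ \dfrac{x_i}{|x_i|} & \textrm{otherwise} \end{cases}
$$

where $|x_i|$ is the absolute value of $x_i$.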

But I think only the most recent versions of libraries follow this, if any do. Older versions of all libraries, and even the most recent versions of some (e.g. CuPy, and even array_api_strict, which I can report separately if need be), use other conventions. It would be helpful if array_api_compat provided, for every library, a `sign` alias that follows the new definition.
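A minimal NumPy sketch of the two conventions may help. Both helpers are hypothetical (not part of NumPy or array_api_compat): `spec_sign` follows the 2022.12 definition (x / |x|, and 0 where x == 0), and `legacy_sign` mimics the documented NumPy 1.x behaviour for complex input, `sign(x.real) + 0j`, falling back to `sign(x.imag) + 0j` when the real part is zero.

```python
import numpy as np


def spec_sign(x):
    """Hypothetical helper: sign() per the 2022.12 definition.

    Intended here for complex input: returns x / |x| elementwise,
    and 0 where x == 0.
    """
    x = np.asarray(x)
    mag = np.abs(x)
    # Divide by 1 where |x| == 0 so no division-by-zero warning is raised,
    # then force those positions to 0.
    safe = np.where(mag == 0, 1, mag)
    return np.where(mag == 0, 0, x / safe).astype(x.dtype)


def legacy_sign(x):
    """Hypothetical helper mimicking NumPy 1.x for complex input:
    sign(x.real) + 0j, or sign(x.imag) + 0j where the real part is 0."""
    x = np.asarray(x)
    s = np.where(x.real != 0, np.sign(x.real), np.sign(x.imag))
    return s.astype(x.dtype)


x = np.array([1 + 2j, -3 + 0j, 0j])
print(spec_sign(x))    # elementwise: (1+2j)/sqrt(5), -1+0j, 0j
print(legacy_sign(x))  # elementwise: 1+0j, -1+0j, 0j
```

The two conventions agree on purely real inputs and on zero; they only diverge when the imaginary part is nonzero, which is why a quick test with a real-valued complex array can miss the difference.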

asmeurer commented 4 weeks ago

`sign` for torch was already fixed in https://github.com/data-apis/array-api-compat/pull/137/files. I didn't realize CuPy had the issue too. Do older versions of NumPy have this problem as well?

mdhaber commented 4 weeks ago

Yes. Basically everything needs to be patched unless it is recent enough. (I know the torch error would not occur with array_api_compat main, but this is what a Colab environment looks like after `!pip install array_api_compat array_api_strict`.)

```python
import array_api_compat
print(array_api_compat.__version__)  # 1.8

from array_api_compat import numpy as xp
print(xp.__version__)  # 1.26.4
x = xp.asarray(1 + 2j)
print(xp.sign(x))  # (1+0j)

import cupy as cp
print(cp.__version__)  # 12.2.0
from array_api_compat import cupy as xp
x = xp.asarray(1 + 2j)
print(xp.sign(x))  # (1+0j)

import torch
print(torch.__version__)  # 2.4.1+cu121
from array_api_compat import torch as xp
x = xp.asarray(1 + 2j)
# print(xp.sign(x))  # RuntimeError: Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.

import dask
print(dask.__version__)  # 2024.8.0
from array_api_compat.dask import array as xp
x = xp.asarray(1 + 2j)
print(xp.sign(x))  # dask.array<sign, shape=(), dtype=complex128, chunksize=(), chunktype=numpy.ndarray>

import array_api_strict as xp
print(xp.__version__)  # 2.0.1
x = xp.asarray(1 + 2j)
print(xp.sign(x))  # (1+0j)

import jax
print(jax.__version__)  # 0.4.26
import jax.numpy as xp
x = xp.asarray(1 + 2j)
print(xp.sign(x))  # (0.44721365+0.8944273j)
```
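As a worked check of these outputs: for x = 1 + 2j, the 2022.12 definition gives

$$
\operatorname{sign}(1+2i) = \frac{1+2i}{|1+2i|} = \frac{1+2i}{\sqrt{5}} \approx 0.4472 + 0.8944i
$$

This matches the JAX output up to float32 rounding, while the (1+0j) printed by NumPy 1.26.4, CuPy 12.2.0 and array_api_strict 2.0.1 is the sign of the real part alone, i.e. the older convention.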

Basically, jax.numpy is the only one that works with the default installation, and the JAX changelog (jax 0.4.24, Feb 6, 2024) shows that it was only updated in February.
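A quick way to check which convention a given namespace uses is to compare `xp.sign(x)` with `x / abs(x)` on a few nonzero complex values. In the sketch below, `follows_2022_sign` is a hypothetical helper. It assumes the namespace provides `asarray`, `abs` and `sign`, and that results can be converted with `numpy.asarray`, which does not hold for every library (CuPy, for instance, refuses implicit conversion of device arrays to NumPy). The two demo namespaces are stand-ins built on NumPy, one for each convention.

```python
import types

import numpy as np


def follows_2022_sign(xp, values=(1 + 2j, -3 + 4j, 0.5 - 0.5j)):
    """Hypothetical check: does xp.sign(x) equal x / |x| for nonzero complex x?

    Assumes xp provides asarray, abs and sign, and that results can be
    converted with numpy.asarray (not true for e.g. CuPy device arrays).
    """
    x = xp.asarray(list(values))
    got = np.asarray(xp.sign(x))
    expected = np.asarray(x) / np.asarray(xp.abs(x))
    return bool(np.allclose(got, expected))


# Stand-in namespace following the 2022.12 definition (nonzero inputs only).
new_style = types.SimpleNamespace(
    asarray=np.asarray,
    abs=np.abs,
    sign=lambda x: x / np.abs(x),
)

# Stand-in namespace using the older convention: sign of the real part as a
# complex number (the real == 0 fallback is omitted for brevity).
old_style = types.SimpleNamespace(
    asarray=np.asarray,
    abs=np.abs,
    sign=lambda x: np.sign(x.real).astype(x.dtype),
)

print(follows_2022_sign(new_style))  # True
print(follows_2022_sign(old_style))  # False
```

Testing only nonzero inputs sidesteps the x == 0 special case; a fuller check would also cover zeros and NaNs.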

asmeurer commented 4 weeks ago

Interesting. The test suite should be checking this as far as I can tell, but it hasn't come up, even though we do explicitly test against older versions of NumPy. That will require some investigation.

asmeurer commented 5 days ago

So I dug into this, and it looks like the test suite has been ignoring any exceptions raised by the reference implementations in the elementwise function tests. This appears to affect quite a few functions, although it isn't clear yet whether it hides any actual unwrapped incompatibilities other than the `sign` one you've pointed out.
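The failure mode described here can be illustrated with a small, self-contained sketch. This is not the test suite's actual code: all helper names are hypothetical, and the reason the reference raises (a real-only `math.copysign` call) is an assumption chosen for illustration. The point is that a blanket `try/except` around the reference computation turns a reference failure into a silently skipped case, so an implementation using the old convention is never compared against anything.

```python
import math


def ref_sign(x):
    """Hypothetical reference written with real inputs in mind.

    math.copysign only accepts real numbers, so this raises TypeError
    for complex x (an assumed cause, for illustration only).
    """
    if x == 0:
        return 0.0
    return math.copysign(1.0, x)


def ref_sign_2022(x):
    """Hypothetical complex-aware reference: x / |x|, or 0 for x == 0."""
    return 0 if x == 0 else x / abs(x)


def old_convention_sign(z):
    """Stand-in implementation: sign of the real part, returned as complex."""
    r = z.real
    return complex((r > 0) - (r < 0))


def check_buggy(func, ref, inputs):
    for x in inputs:
        try:
            expected = ref(x)
        except Exception:
            continue  # reference raised: the case is silently skipped
        assert func(x) == expected, (x, func(x), expected)


def check_strict(func, ref, inputs):
    for x in inputs:
        expected = ref(x)  # an error here now fails the check instead
        assert func(x) == expected, (x, func(x), expected)


inputs = [1 + 2j, -3 + 4j]

# "Passes" even though old_convention_sign is wrong: every case was skipped.
check_buggy(old_convention_sign, ref_sign, inputs)

# The strict version surfaces the broken reference...
try:
    check_strict(old_convention_sign, ref_sign, inputs)
except TypeError as exc:
    print("reference raised:", exc)

# ...and with a complex-aware reference it reports the actual mismatch.
try:
    check_strict(old_convention_sign, ref_sign_2022, inputs)
except AssertionError:
    print("mismatch detected for the old convention")
```

Letting the exception propagate, or at least counting skipped cases and failing when all of them are skipped, is one way to make such gaps visible.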