mdhaber opened 1 month ago
`sign` for torch was already fixed in https://github.com/data-apis/array-api-compat/pull/137/files. I didn't realize CuPy had the issue too. Do older versions of NumPy have this problem as well?
Yes. Basically everything needs to be patched unless it is recent enough. (I know the torch error would not be present with `array_api_compat` `main`, but this is what an environment on Colab looks like after `!pip install array_api_compat array_api_strict`.)
```python
import array_api_compat
print(array_api_compat.__version__)  # 1.8
from array_api_compat import numpy as xp
print(xp.__version__)  # 1.26.4
x = xp.asarray(1 + 2j)
print(xp.sign(x))  # (1+0j)

import cupy as cp
print(cp.__version__)  # 12.2.0
from array_api_compat import cupy as xp
x = xp.asarray(1 + 2j)
print(xp.sign(x))  # (1+0j)

import torch
print(torch.__version__)  # 2.4.1+cu121
from array_api_compat import torch as xp
x = xp.asarray(1 + 2j)
# print(xp.sign(x))  # RuntimeError: Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.

import dask
print(dask.__version__)  # 2024.8.0
from array_api_compat.dask import array as xp
x = xp.asarray(1 + 2j)
print(xp.sign(x))  # dask.array<sign, shape=(), dtype=complex128, chunksize=(), chunktype=numpy.ndarray>

import array_api_strict as xp
print(xp.__version__)  # 2.0.1
x = xp.asarray(1 + 2j)
print(xp.sign(x))  # (1+0j)

import jax
print(jax.__version__)  # 0.4.26
import jax.numpy as xp
x = xp.asarray(1 + 2j)
print(xp.sign(x))  # (0.44721365+0.8944273j)
```
Basically, `jax.numpy` is the only one that works with the default installation, and I see in the Change Log entry for jax 0.4.24 (Feb 6, 2024) that it was only updated in February.
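A quick way to run the same comparison programmatically is to compare `xp.sign(z)` with `z / abs(z)` for one nonzero complex value. The sketch below is hypothetical (the helper `sign_follows_2022_12` is not part of any library). It assumes `xp` exposes `asarray` and `sign`, and that a lazy backend result can be evaluated with `.compute()`, as dask arrays can. Backends that reject complex input, like the torch case above, raise an exception instead of returning `False`.

```python
import types

import numpy as np


def sign_follows_2022_12(xp, z=1 + 2j, tol=1e-6):
    """Return True if xp.sign(z) matches z/|z| for a nonzero complex scalar z.

    Hypothetical helper: it calls xp.sign directly, so a namespace that
    rejects complex input raises instead of returning False.
    """
    result = xp.sign(xp.asarray(z))
    # Lazy backends (e.g. dask) return an unevaluated array; force evaluation.
    if hasattr(result, "compute"):
        result = result.compute()
    expected = z / abs(z)
    return abs(complex(result) - expected) < tol


# Demo with two stand-in namespaces built on NumPy (not real libraries):
# one implements x/|x|, the other mimics the (1+0j) result seen above.
new_style = types.SimpleNamespace(asarray=np.asarray, sign=lambda x: x / np.abs(x))
old_style = types.SimpleNamespace(asarray=np.asarray, sign=lambda x: np.ones_like(x))
print(sign_follows_2022_12(new_style))  # True
print(sign_follows_2022_12(old_style))  # False
```

With the real libraries, you would pass a namespace such as `array_api_compat.numpy` or `array_api_strict` as `xp`.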
Interesting. The test suite should be checking this as far as I can tell, but it hasn't come up, even though we do explicitly test against older versions of NumPy. That will require some investigation.
So I dug into this, and it looks like the test suite has been ignoring any exceptions raised in the reference implementations in the elementwise function tests. This appears to affect quite a few functions, although it isn't clear yet whether this has hidden any actual unwrapped incompatibilities other than the `sign` one you've pointed out.
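To illustrate how swallowing reference-implementation exceptions hides a wrong result, here is a simplified, hypothetical test helper. It is not the actual test-suite code, and `naive_ref_sign` is only an assumed example of a reference that breaks on complex input. In Python, ordering comparisons such as `z > 0` raise `TypeError` for complex `z`. If that exception is caught and the case is skipped, the function under test is never checked.

```python
import cmath


def check_case(func, ref, x, *, swallow_ref_errors):
    """Compare func(x) against a reference ref(x) for one input (hypothetical helper).

    If the reference raises and swallow_ref_errors is True, the case is skipped
    without any check, which is how a wrong result can go unnoticed.
    """
    try:
        expected = ref(x)
    except Exception:
        if swallow_ref_errors:
            return "skipped"  # nothing about func(x) is ever checked
        raise  # surface the broken reference so the case gets reported
    actual = func(x)
    assert cmath.isclose(complex(actual), complex(expected)), f"{actual!r} != {expected!r}"
    return "passed"


def naive_ref_sign(z):
    # Hypothetical reference: correct for real numbers, but ordering
    # comparisons raise TypeError for Python complex numbers.
    return (z > 0) - (z < 0)


def old_style_sign(z):
    # Simplified stand-in for a pre-2022.12 convention: sign of the real part only.
    return complex((z.real > 0) - (z.real < 0))


print(check_case(old_style_sign, naive_ref_sign, -3.0, swallow_ref_errors=True))  # passed
print(check_case(old_style_sign, naive_ref_sign, 1 + 2j, swallow_ref_errors=True))  # skipped
```

With `swallow_ref_errors=False`, the complex case raises `TypeError` instead. Once the reference computes `z / abs(z)`, `old_style_sign` fails the assertion, which is the failure the suite should have reported.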
Since the 2022.12 standard, the required definition of `sign` for complex input has been:

$$\operatorname{sign}(x_i) = \begin{cases} 0 & \text{if } x_i = 0 \\ \dfrac{x_i}{|x_i|} & \text{otherwise} \end{cases}$$

but I think only the most recent versions of libraries follow this (if any). Older versions of all libraries, and even the most recent versions of some (e.g. CuPy, and even `array_api_strict`, which I can report separately if need be), use other conventions. It would be helpful if all libraries had aliases of `sign` that use the new definition.
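For reference, here is a minimal NumPy sketch of the 2022.12 definition. The name `spec_sign` is hypothetical, and this is not the `array_api_compat` implementation. Real input falls through to `np.sign`. Complex input is divided by its magnitude, and exact zeros map to zero. The divisor is replaced by 1 at zeros so that no `0/0` is ever evaluated.

```python
import numpy as np


def spec_sign(x):
    """Sketch of the 2022.12 array API sign: x/|x| for complex x != 0, else 0."""
    x = np.asarray(x)
    if not np.iscomplexobj(x):
        return np.sign(x)  # real case: -1, 0 or +1 (NaN propagates)
    mag = np.abs(x)
    # Divide by 1 where x == 0 to avoid 0/0; those entries are set to 0 below.
    safe = np.where(mag == 0, 1, mag)
    # NaN inputs give NaN magnitudes, so they stay NaN through the division.
    return np.where(mag == 0, 0, x / safe).astype(x.dtype)


print(spec_sign(np.asarray(1 + 2j)))
print(spec_sign(np.array([3 + 4j, 0j, -2j])))
```

Keeping the real branch on `np.sign` means only the complex behavior changes, which matches the scope of the 2022.12 update.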