data-apis / array-api-compat

Compatibility layer for common array libraries to support the Array API
https://data-apis.org/array-api-compat/
MIT License

`sign` special case implementations #136

Closed. mdhaber closed this 3 days ago.

mdhaber commented 4 months ago

According to v2022.12 (and v2023.12) of the array API standard, the special cases of sign include:

> For real-valued operands... If x_i is NaN, the result is NaN.

However,

```python
from array_api_compat import numpy as np, cupy as cp, torch
np.sign(np.asarray(np.nan))  # nan
cp.sign(cp.asarray(cp.nan))  # array( 0.000e+00)
torch.sign(torch.asarray(torch.nan))  # tensor(0.)
```

There may be other special cases that are not yet implemented. I haven't done a complete review, but I noticed that torch gives an error when the input is complex.

```python
torch.sign(torch.asarray(1+1j))
# RuntimeError: Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.
```

asmeurer commented 4 months ago

I haven't tried to implement any special-case workarounds here yet. I guess if this is causing issues for you, we can add a workaround.

asmeurer commented 4 months ago

We should definitely add complex support to torch.sign. It sounds like that will just be a simple wrapper around torch.sgn.
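For reference, a minimal pure-Python sketch of the scalar semantics under discussion: for nonzero real x, sign is -1 or +1; zero maps to 0; NaN propagates. For complex input it returns the unit-modulus value z/|z|, which is what torch.sgn computes, with 0 for 0. The name `spec_sign` is hypothetical, the NaN + NaN j result for complex NaN input is an assumption of this sketch, and infinite complex components are not special-cased.

```python
import cmath
import math


def spec_sign(x):
    """Hypothetical scalar sketch of array-API-style `sign` semantics.

    Real input: -1.0, 0.0, or 1.0, and NaN is propagated.
    Complex input: x / |x| (unit modulus, like torch.sgn), 0j for a zero
    input, and NaN + NaN j (assumed here) if either component is NaN.
    Infinite complex components are not handled specially.
    """
    if isinstance(x, complex):
        if cmath.isnan(x):  # True if the real or imaginary part is NaN
            return complex(math.nan, math.nan)
        if x == 0:
            return 0j
        return x / abs(x)
    x = float(x)
    if math.isnan(x):
        return x  # propagate NaN instead of mapping it to 0
    if x > 0:
        return 1.0
    if x < 0:
        return -1.0
    return 0.0  # +0.0 or -0.0


print(spec_sign(float("nan")), spec_sign(-2.5), spec_sign(3 + 4j))
```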

mdhaber commented 4 months ago

Yeah, I actually did run into this special case while converting code from NumPy to the array API. The purpose of the code is to produce a status flag that depends on whether an element is positive, negative, zero, or NaN, so no wonder I ran into it. So it would be helpful to add the special case, but support for 2023.12 would be a higher priority for me.
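To illustrate that kind of use case: with a NaN-propagating sign, a four-way status flag can be read directly off the result, since -1, 0, and +1 cover the ordered cases and NaN marks the remaining one. This sketch uses NumPy, whose np.sign already returns NaN for NaN. The flag values and the helper name `status_flags` are hypothetical and are not taken from the code being converted.

```python
import numpy as np

# Hypothetical flag values, chosen only for illustration.
NEGATIVE, ZERO, POSITIVE, IS_NAN = -1, 0, 1, 2


def status_flags(x):
    """Map each element of a real float array to a status flag via sign()."""
    s = np.sign(x)  # -1.0, 0.0, 1.0, or NaN with NumPy's sign
    # NaN cannot be cast to an integer flag, so replace it before casting.
    return np.where(np.isnan(s), IS_NAN, s).astype(np.int8)


x = np.array([-3.0, 0.0, 2.5, np.nan])
flags = status_flags(x)  # one flag per element: negative, zero, positive, NaN
```

With a backend whose sign maps NaN to 0, the NaN element would silently get the ZERO flag instead.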

asmeurer commented 3 months ago

We can add it. The main concern with adding special-case handling is that it means adding a mask to the function, so it could be a minor hit to performance. I would at least make sure there are upstream issues about this for the libraries that don't implement it correctly.
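A minimal sketch of what such a mask looks like, written against a generic namespace `xp` that is assumed to provide array-API-style `sign`, `isnan`, and `where`. The wrapper name is hypothetical and is not the compat layer's actual code. The stand-in backend is built on NumPy only to mimic the sign(NaN) == 0 behavior reported above for CuPy and PyTorch.

```python
import types

import numpy as np


def sign_propagating_nan(x, xp):
    """Hypothetical compat-style wrapper that forces sign(NaN) == NaN.

    `xp` is any namespace exposing array-API `sign`, `isnan`, and `where`.
    The extra isnan() pass plus the where() select is the "mask" whose
    cost is being discussed. Intended for real floating-point arrays.
    """
    return xp.where(xp.isnan(x), x, xp.sign(x))


# Stand-in backend that maps sign(NaN) to 0, like the CuPy/PyTorch results
# reported above. It is built on NumPy purely for demonstration.
zero_for_nan = types.SimpleNamespace(
    sign=lambda a: np.nan_to_num(np.sign(a), nan=0.0),
    isnan=np.isnan,
    where=np.where,
)

x = np.array([-2.0, 0.0, np.nan, 5.0])
raw = zero_for_nan.sign(x)                     # NaN comes back as 0.0
fixed = sign_propagating_nan(x, zero_for_nan)  # the mask restores NaN
```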

rgommers commented 3 months ago

> The main concern with adding special-case handling is it means adding a mask to the function, so it could be a minor hit to performance.

I think that won't be a minor hit. For element-wise functions, adding isnan checks and masking may slow things down by 2x or more.
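One rough way to check that estimate locally is to time np.sign against the same call wrapped in an isnan/where mask. This sketch measures NumPy on CPU only. The ratio depends on hardware, array size, and NaN density, and says nothing directly about GPU backends. `masked_sign` and `overhead_ratio` are hypothetical helpers.

```python
import timeit

import numpy as np


def masked_sign(x):
    """Sign with an explicit NaN mask, as a compat wrapper might add."""
    return np.where(np.isnan(x), x, np.sign(x))


def overhead_ratio(n=1_000_000, repeat=5, number=10, seed=0):
    """Return best-of-`repeat` time(masked_sign) / time(np.sign) on n floats."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n)
    x[::100] = np.nan  # sprinkle in some NaNs
    plain = min(timeit.repeat(lambda: np.sign(x), repeat=repeat, number=number))
    masked = min(timeit.repeat(lambda: masked_sign(x), repeat=repeat, number=number))
    return masked / plain


print(f"masked / plain: {overhead_ratio():.2f}x")
```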

In general, these special cases have not been well validated yet, so I'd be quite reluctant to assume they are all correctly specified, or to support them in the compat layer.

In this case, it may be possible to fix in NumPy. There's a bunch of special-casing to make sign(nan) return nan, which looks like old code and disagrees with C99's signbit on purpose:

https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/_core/src/umath/loops.c.src#L1341-L1350

https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/_core/src/umath/loops.c.src#L2262-L2271

Changing it would be a minor backwards-compatibility break, but probably an improvement. Much larger changes were made to np.sign recently in gh-25441 (https://github.com/numpy/numpy/pull/25441); however, sign(nan) = nan was not discussed there. It does seem like sign deviates from signbit for no good reason; it was probably an ad-hoc pre-C99 decision.
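To make the contrast concrete: a sign derived from the sign bit (the information C99's signbit and copysign expose) returns +1 or -1 for NaN depending on its bit pattern. A chain of ordered comparisons instead falls through for NaN, because every comparison with NaN is false, so it can return the NaN itself. A pure-Python sketch; both helper names are hypothetical.

```python
import math


def sign_from_signbit(x):
    """Hypothetical: sign taken from the sign bit, so NaN maps to +1 or -1."""
    if x == 0:
        return 0.0
    return math.copysign(1.0, x)  # for NaN this reads only the sign bit


def sign_propagating(x):
    """Hypothetical: ordered comparisons; NaN fails all of them and passes through."""
    if x > 0:
        return 1.0
    if x < 0:
        return -1.0
    if x == 0:
        return 0.0
    return x  # only NaN reaches this point


nan = float("nan")
print(sign_from_signbit(nan))  # +1.0 or -1.0, depending on the NaN's sign bit
print(sign_propagating(nan))   # nan
```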

asmeurer commented 3 months ago

NumPy is the one that agrees with what the array API says. If we think sign(nan) should not return NaN, then the array API should change.

rgommers commented 3 months ago

> If we think sign(nan) should not return NaN, then the array API should change.

I'm honestly not sure. These special cases are a pain, it requires a lot of time investment to figure out what was discussed before, and why different libraries end up with different return values. I suspect the following:

seberg commented 3 months ago

I think one normal rule for NaNs is that they are propagated unless there is a clear reason not to. After all, sign is very different from signbit: signbit is defined to return a bool, and it additionally maps to an implementation detail of the IEEE float representation.

Further, returning NaN preserves the full "partial order" (e.g. in C++20) of value <=> 0 (less, equal, greater, and unordered).
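A small sketch of that point: with a NaN-propagating sign, all four outcomes of comparing a float to zero (less, equal, greater, unordered) can be recovered from the sign result alone. If sign(NaN) returned 0, NaN would be indistinguishable from "equal". `compare_to_zero` is a hypothetical helper, and the string labels are just the names used in this thread.

```python
import math


def nan_propagating_sign(x):
    """-1.0, 0.0, or 1.0 for ordered floats; NaN is returned unchanged."""
    return 1.0 if x > 0 else -1.0 if x < 0 else 0.0 if x == 0 else x


def compare_to_zero(x):
    """Hypothetical helper: classify `x <=> 0` using only the sign result."""
    s = nan_propagating_sign(x)
    if math.isnan(s):
        return "unordered"
    return {-1.0: "less", 0.0: "equal", 1.0: "greater"}[s]


for v in (-3.0, -0.0, 2.5, float("nan")):
    print(v, compare_to_zero(v))
```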

So IMO, NumPy does the right thing unless there is a good argument that typical use-cases would expect sign to return 0 for NaN. And it sounds like the report here is a use-case where you want the full partial order to be preserved!

TBH, I would lean towards torch fixing this, but if there is a good reason why they don't want to (what is it?), then it has to stay undefined, which seems unfortunate for the actual use-case above.

asmeurer commented 3 months ago

I found this old discussion https://mail.python.org/archives/list/numpy-discussion@python.org/thread/A2JFHOZZOF634CNZ7E27THQEBU4EZFTS/#F3KS7QWVPIXSYB7CSSY37OXYM4JVZTZQ

sign and signbit are completely different things for complex numbers, as I pointed out at https://mail.python.org/archives/list/numpy-discussion@python.org/message/VBYOVSTN2GTBPEJ3OPDS2S5DLPQJFFX3/ It's probably not a coincidence that torch also doesn't define complex sign.

It does seem valuable to figure this out, since there's at least one real-world use-case. Maybe we should move this discussion to the array-api repo.

asmeurer commented 3 months ago

I have a fix at https://github.com/data-apis/array-api-compat/pull/137, which we can at least use to check the performance implications.

asmeurer commented 3 months ago

By the way, you can see the other special cases that are not being followed in the xfails files (we do not yet attempt to fix any of them, and most are on operators, which can't be fixed anyway):

torch https://github.com/data-apis/array-api-compat/blob/376038ed9f4337cdec78f21a5ccb3e2b6d948786/torch-xfails.txt#L89-L179

cupy https://github.com/data-apis/array-api-compat/blob/376038ed9f4337cdec78f21a5ccb3e2b6d948786/cupy-xfails.txt#L61-L169

numpy https://github.com/data-apis/array-api-compat/blob/376038ed9f4337cdec78f21a5ccb3e2b6d948786/numpy-xfails.txt#L18-L41

I will say that even though there are quite a few of these, the sign special case seems to stand out. Almost all the rest seem to have to do with handling -0 correctly.
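For context on why -0 cases are so common: IEEE 754 signed zeros compare equal, so a result of -0.0 where +0.0 is required (or vice versa) is invisible to `==`, yet it can change downstream results. Checking it requires inspecting the sign bit. A small pure-Python illustration; `same_float` is a hypothetical helper.

```python
import math


def same_float(a, b):
    """Hypothetical strict equality: distinguishes +0.0 from -0.0, treats NaN == NaN."""
    if math.isnan(a) or math.isnan(b):
        return math.isnan(a) and math.isnan(b)
    return a == b and math.copysign(1.0, a) == math.copysign(1.0, b)


print(-0.0 == 0.0)                                # True: == ignores the zero's sign
print(same_float(-0.0, 0.0))                      # False: the sign bits differ
print(math.atan2(0.0, -0.0), math.atan2(0.0, 0.0))  # 3.141592653589793 0.0
```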

kgryte commented 3 months ago

@rgommers The addition of the special case for NaN handling in sign comes from https://github.com/data-apis/array-api/pull/556.

The specification is correct on this, as returning NaN follows naturally from the definition of the signum function where

$$\operatorname{sgn} x = \frac{x}{|x|}$$

and where NaN/NaN follows the special cases for division, thus ensuring arithmetic consistency.
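Spelled out for a NaN input, the argument is a one-line chain: |NaN| is NaN, and IEEE 754 division with a NaN operand yields NaN. Zero is the one real input for which x/|x| is undefined, so sgn 0 = 0 is fixed separately, as in the usual piecewise definition.

$$
\operatorname{sgn} x =
\begin{cases}
-1 & x < 0 \\
0 & x = 0 \\
1 & x > 0
\end{cases}
\qquad
\operatorname{sgn}(\mathrm{NaN}) = \frac{\mathrm{NaN}}{|\mathrm{NaN}|} = \frac{\mathrm{NaN}}{\mathrm{NaN}} = \mathrm{NaN}
$$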

This was not a pre-C99 oversight and was intentional. We shouldn't expect the signum and signbit functions to return equivalent results, and NumPy is correct here. It would perhaps have been better if NumPy had chosen signum as the name rather than sign to make this delineation in behavior more clear.

seberg commented 3 months ago

In case it's interesting: it turns out that returning NaN may make this faster on the GPU: https://github.com/cupy/cupy/issues/8327

mdhaber commented 1 week ago

Just thought I'd mention that I ran into this again in a different context.