data-apis / array-api-compat

Compatibility layer for common array libraries to support the Array API
https://data-apis.org/array-api-compat/
MIT License

Fix sign() for torch and cupy #137

Closed asmeurer closed 3 days ago

asmeurer commented 3 months ago

Neither torch.sign nor cupy.sign propagates nans correctly, and torch.sign does not support complex numbers.

Fixes https://github.com/data-apis/array-api-compat/issues/136

asmeurer commented 3 months ago

https://github.com/data-apis/array-api-compat/issues/136 should be resolved before this is merged. Specifically, we should decide whether it's worth fixing the sign(nan) special case, and whether we want to keep that special case at all. Regardless of that decision, we should keep the torch complex handling, as it's very straightforward to implement.
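For reference, here is a minimal scalar sketch of the two behaviours under discussion: nan propagation and complex support. It is written in plain Python and is not the code in this PR. The name `reference_sign` is hypothetical, and the special cases follow my reading of the Array API `sign` specification, so treat them as assumptions.

```python
import cmath
import math


def reference_sign(x):
    """Hypothetical scalar model of sign() with nan propagation and complex support."""
    if isinstance(x, complex):
        # Assumption: a nan in either component makes both components nan.
        if cmath.isnan(x):
            return complex(math.nan, math.nan)
        # Complex zero maps to zero; everything else is scaled to unit modulus.
        if x == 0:
            return 0j
        return x / abs(x)
    x = float(x)
    # Real nan propagates instead of being mapped to 0 or +/-1.
    if math.isnan(x):
        return math.nan
    if x == 0:
        return 0.0
    return math.copysign(1.0, x)
```

An array version would apply the same rules elementwise, for example by computing the library's native sign and then overwriting the nan positions.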

asmeurer commented 3 months ago

It seems that torch has gained quite a few new test failures since the last time we ran them. I don't know if that's because of a test suite update or a torch update.

asmeurer commented 3 months ago

Based on a simple timing test on PyTorch CPU, this implementation is 3-10x slower than torch.sign, depending on how many nans are in the tensor, although sign itself is a fast operation to begin with. It would definitely be better for this to be fixed upstream.

asmeurer commented 3 days ago

It sounds like this will be useful, so I'm going to merge.