data-apis / array-api-compat

Compatibility layer for common array libraries to support the Array API
https://data-apis.org/array-api-compat/
MIT License

BUG: fix `torch.result_type` cross-kind promotion #55

Closed: lucascolley closed this 1 year ago

lucascolley commented 1 year ago

Reference comment: https://github.com/scipy/scipy/pull/19051#issuecomment-1699295300 @rgommers

Expected behaviour: array_api_compat.torch.result_type should carry out cross-kind promotion (for example, an integer tensor combined with a floating-point dtype), just as torch.result_type does.
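For reference, this is the kind of cross-kind promotion `torch.result_type` already performs when both operands are tensors or Python numbers. The Array API standard itself leaves mixed integer/floating-point promotion unspecified, which is why the compat layer falls back to torch here. A small illustration, assuming PyTorch's default floating dtype is left at `float32`:

```python
import torch

ints = torch.tensor([1, 2])                               # int64
floats64 = torch.tensor([1.0, 2.0], dtype=torch.float64)  # float64

# Tensor/tensor: integer and floating kinds promote to the floating dtype.
print(torch.result_type(ints, floats64))  # torch.float64

# Tensor/Python float: promotes to the default floating dtype (float32).
print(torch.result_type(ints, 1.0))       # torch.float32
```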

Observed behaviour:

In [1]: import torch

In [2]: t = torch.tensor([[0, 2], [1, 1], [2, 0]]).T

In [3]: from array_api_compat import array_namespace

In [4]: xp = array_namespace(t)

In [5]: xp.result_type(t, xp.float64)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 1
----> 1 xp.result_type(t, xp.float64)

File ~/dev/array-api-compat/array_api_compat/torch/_aliases.py:136, in result_type(*arrays_and_dtypes)
    131     return _promotion_table[xdt, ydt]
    133 # This doesn't result_type(dtype, dtype) for non-array API dtypes
    134 # because torch.result_type only accepts tensors. This does however, allow
    135 # cross-kind promotion.
--> 136 return torch.result_type(x, y)

TypeError: result_type() received an invalid combination of arguments - got (Tensor, torch.dtype), but expected one of:
 * (Tensor tensor, Tensor other)
      didn't match because some of the arguments have invalid types: (Tensor, torch.dtype)
 * (Number scalar, Tensor tensor)
      didn't match because some of the arguments have invalid types: (Tensor, torch.dtype)
 * (Tensor tensor, Number other)
      didn't match because some of the arguments have invalid types: (Tensor, torch.dtype)
 * (Number scalar1, Number scalar2)
      didn't match because some of the arguments have invalid types: (Tensor, torch.dtype)
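The comment quoted in the traceback points at the root cause: `torch.result_type` only accepts tensors and Python numbers, so a bare `torch.dtype` is rejected. One possible workaround, sketched here under the assumption that a 0-d placeholder tensor is an acceptable stand-in for a dtype, is to wrap dtypes before delegating. The helper name `result_type_pair` is hypothetical, and this is not necessarily how the branch implements the fix.

```python
import torch


def result_type_pair(x, y):
    """Hypothetical two-operand helper (not array_api_compat's code).

    torch.result_type accepts only tensors and Python numbers, so any
    bare torch.dtype is replaced by a 0-d placeholder tensor of that
    dtype before delegating. Torch then applies its own cross-kind
    promotion rules.
    """
    def as_operand(a):
        if isinstance(a, torch.dtype):
            # torch.empty(()) allocates a 0-d tensor; its contents are unused.
            return torch.empty((), dtype=a)
        return a  # tensors and Python numbers pass through unchanged

    return torch.result_type(as_operand(x), as_operand(y))


t = torch.tensor([[0, 2], [1, 1], [2, 0]]).T            # int64, 2-D
print(result_type_pair(t, torch.float64))               # torch.float64
print(result_type_pair(torch.int32, torch.float32))     # torch.float32
```

One caveat of this approach: PyTorch gives 0-d tensors lower priority than dimensioned tensors of the same kind, so a 2-D `float32` tensor combined with a wrapped `float64` stays `float32`. Judging from the traceback, same-kind Array API dtypes appear to be resolved by the `_promotion_table` lookup before this fallback is reached.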

Improved behaviour on this branch:

In [1]: import torch

In [2]: t = torch.tensor([[0, 2], [1, 1], [2, 0]]).T

In [3]: from array_api_compat import array_namespace

In [4]: xp = array_namespace(t)

In [5]: xp.result_type(t, xp.float64)
Out[5]: torch.float64

rgommers commented 1 year ago

Okay, this is almost certainly correct and it has been open for a week. So I'll hit the green button here. Thanks @lucascolley!