Expected behaviour: array_api_compat.torch.result_type should carry out cross-kind promotion (for example, an integer tensor combined with a floating-point dtype) the same way torch.result_type does.
Observed behaviour:
In [1]: import torch
In [2]: t = torch.tensor([[0, 2], [1, 1], [2, 0]]).T
In [3]: from array_api_compat import array_namespace
In [4]: xp = array_namespace(t)
In [5]: xp.result_type(t, xp.float64)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[5], line 1
----> 1 xp.result_type(t, xp.float64)
File ~/dev/array-api-compat/array_api_compat/torch/_aliases.py:136, in result_type(*arrays_and_dtypes)
131 return _promotion_table[xdt, ydt]
133 # This doesn't result_type(dtype, dtype) for non-array API dtypes
134 # because torch.result_type only accepts tensors. This does however, allow
135 # cross-kind promotion.
--> 136 return torch.result_type(x, y)
TypeError: result_type() received an invalid combination of arguments - got (Tensor, torch.dtype), but expected one of:
* (Tensor tensor, Tensor other)
didn't match because some of the arguments have invalid types: (Tensor, torch.dtype)
* (Number scalar, Tensor tensor)
didn't match because some of the arguments have invalid types: (Tensor, torch.dtype)
* (Tensor tensor, Number other)
didn't match because some of the arguments have invalid types: (Tensor, torch.dtype)
* (Number scalar1, Number scalar2)
didn't match because some of the arguments have invalid types: (Tensor, torch.dtype)
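The overloads listed in the error show that torch.result_type itself only accepts tensors and Python numbers, never a bare torch.dtype. The snippet below illustrates which native calls do work for the int64 tensor t from the session above: a tensor plus a number, a tensor plus an empty tensor standing in for a dtype, and torch.promote_types for dtype-only promotion. The float32 result for the scalar case assumes torch's default dtype has not been changed.

```python
import torch

t = torch.tensor([[0, 2], [1, 1], [2, 0]]).T  # int64 tensor, as in the session above

# (Tensor, Number): a Python float promotes an integer tensor to the default float dtype.
print(torch.result_type(t, 1.0))

# (Tensor, Tensor): a dtype can be represented by an empty tensor of that dtype.
print(torch.result_type(t, torch.empty((0,), dtype=torch.float64)))

# Dtype-only promotion is available through torch.promote_types.
print(torch.promote_types(t.dtype, torch.float64))
```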
Improved behaviour on this branch:
In [1]: import torch
In [2]: t = torch.tensor([[0, 2], [1, 1], [2, 0]]).T
In [3]: from array_api_compat import array_namespace
In [4]: xp = array_namespace(t)
In [5]: xp.result_type(t, xp.float64)
Out[5]: torch.float64
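One way a dtype argument could be handled is sketched below, under the assumption that each torch.dtype is replaced by an empty tensor of that dtype before calling torch.result_type. This is not necessarily what the branch does, and result_type_sketch and _as_operand are hypothetical names used only for illustration.

```python
import torch


def _as_operand(a):
    """Hypothetical helper: stand in for a torch.dtype with an empty 1-D tensor.

    A 1-D (rather than 0-D) tensor is used because torch gives 0-D tensors
    lower priority than dimensioned tensors of the same category, which would
    make a float32 array combined with torch.float64 resolve to float32.
    """
    if isinstance(a, torch.dtype):
        return torch.empty((0,), dtype=a)
    return a


def result_type_sketch(*arrays_and_dtypes):
    """Hypothetical result_type accepting any mix of tensors and torch dtypes."""
    if not arrays_and_dtypes:
        raise ValueError("at least one array or dtype is required")
    operands = [_as_operand(a) for a in arrays_and_dtypes]
    dtype = operands[0].dtype
    for op in operands[1:]:
        # Fold pairwise: promote the running dtype (as an empty tensor) with the next operand.
        dtype = torch.result_type(torch.empty((0,), dtype=dtype), op)
    return dtype
```

With this sketch, result_type_sketch(t, torch.float64) returns torch.float64 for the int64 tensor t from the session above, matching the improved output, and dtype-only calls such as result_type_sketch(torch.float32, torch.float64) also work because both operands become tensors.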
Reference comment: https://github.com/scipy/scipy/pull/19051#issuecomment-1699295300 @rgommers