data-apis / array-api

RFC document, tooling and other content related to the array API standard
https://data-apis.github.io/array-api/latest/
MIT License

result_type() for mixed arrays/Python scalars #805

Open shoyer opened 1 month ago

shoyer commented 1 month ago

The array API's type promotion rules support mixed scalar/array operations, e.g., 1 + xp.arange(3).

For Xarray, we would like to be able to figure out the resulting dtype from this sort of operation before actually doing it (https://github.com/pydata/xarray/pull/8946).

Ideally, we could use xp.result_type() for this purpose, but as documented, result_type only supports arrays and dtype objects. Could we extend result_type to also handle Python scalars? This already works today in NumPy, e.g.,

>>> np.result_type(1, np.arange(3))
dtype('int64')
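
To make the use case concrete, here is a small sketch using NumPy, where result_type already accepts Python scalars. The arrays and scalar values are illustrative assumptions, not taken from the Xarray PR; the point is only that the dtype can be predicted without performing the operation.

```python
import numpy as np

# Illustrative inputs (assumptions for this sketch).
x = np.arange(3, dtype=np.int8)
y = np.linspace(0.0, 1.0, 3, dtype=np.float32)

# Predict the result dtype without computing anything...
predicted_int = np.result_type(2, x)
predicted_float = np.result_type(1.5, y)

# ...and compare against the dtype the operation actually produces.
assert predicted_int == (2 + x).dtype
assert predicted_float == (1.5 + y).dtype
```

In both cases the Python scalar defers to the array's dtype, which is the behavior the type promotion rules describe for mixed scalar/array operations.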

asmeurer commented 1 month ago

This makes sense to me. torch seems to support this as well. What should the result be if there are multiple Python scalars? Undefined?

shoyer commented 1 month ago

> What should the result be if there are multiple Python scalars? Undefined?

This should indeed probably be undefined by the spec.

In most cases I imagine array libraries will have a default dtype, but different libraries will make different choices (e.g., int32 in JAX vs int64 in NumPy):

>>> np.result_type(1, 2)
dtype('int64')
>>> jnp.result_type(1, 2)
dtype('int32')

asmeurer commented 1 month ago

One concern I see with this is that libraries need not support Python scalars in functions, only for operators. So result_type(a, b) working does not imply that func(a, b) will work.

shoyer commented 1 month ago

> One concern I see with this is that libraries need not support Python scalars in functions, only for operators. So result_type(a, b) working does not imply that func(a, b) will work.

In Xarray, we are thinking of defining something like:

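def as_shared_dtype(scalars_or_arrays):
    # get_array_namespace: helper assumed to return the inputs' array API namespace
    xp = get_array_namespace(scalars_or_arrays)
    # Relies on result_type accepting Python scalars, as proposed in this issue
    dtype = xp.result_type(*scalars_or_arrays)
    # dtype is keyword-only in the standard's asarray signature
    return tuple(xp.asarray(x, dtype=dtype) for x in scalars_or_arrays)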
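
As a rough illustration of how such a helper could be used, the sketch below runs the same idea against NumPy, whose result_type already accepts Python scalars. The explicit xp parameter stands in for the namespace lookup and is an assumption of this sketch, as are the sample inputs; this is not Xarray's actual implementation.

```python
import numpy as np

def as_shared_dtype(scalars_or_arrays, xp=np):
    # xp is passed explicitly here (an assumption of this sketch) instead of
    # being looked up from the inputs.
    dtype = xp.result_type(*scalars_or_arrays)
    return tuple(xp.asarray(x, dtype=dtype) for x in scalars_or_arrays)

# Illustrative inputs: a float32 array and a Python int scalar.
cond = np.asarray([True, False, True])
x, y = as_shared_dtype([np.arange(3, dtype=np.float32), 0])

# Both arguments are now arrays with a shared dtype, so where() never
# sees a bare Python scalar.
out = np.where(cond, x, y)
```

Here the scalar 0 is converted to a zero-dimensional float32 array before being handed to where(), so the function itself does not need to accept Python scalars.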

asmeurer commented 1 month ago

Does xarray automatically call asarray on scalar function arguments like NumPy does? Certainly the recommendation of the standard is to not do that, because it's cleaner from a typing perspective. Implicitly calling asarray at the top of every function is considered a historical NumPy antipattern. It's not disallowed, but we also should probably avoid standardizing things that encourage it.

keewis commented 1 month ago

The only time we call that function is when preparing arguments for where, which as far as I can tell doesn't support Python scalars. (We also call it for concat / stack, but there we don't expect to encounter Python scalars.)

shoyer commented 1 month ago

Xarray objects always contain array objects, but indeed there are functions like where() for which it's convenient to be able to use scalars.

I opened a separate issue to discuss: https://github.com/data-apis/array-api/issues/807

rgommers commented 1 month ago

This sounds like a useful change to me.

> What should the result be if there are multiple Python scalars? Undefined?

> This should indeed probably be undefined by the spec.

What is the problem? It seems well-defined to allow multiple. If multiple arrays and dtype objects are allowed, why not multiple Python scalars?

keewis commented 1 month ago

I'm not sure, but I think that was referring to a situation where you have no explicit dtypes, just (compatible) Python scalars. In that case, we'd have to make an arbitrary choice (or raise an error).

rgommers commented 1 month ago

Ah of course. Agreed, there must be at least one array or dtype object.
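
A minimal sketch of that constraint, written as a wrapper around an existing result_type. The name result_type_with_scalars is hypothetical, and the sketch assumes that compatible Python scalars never change the promoted dtype; it does not validate that each scalar's kind is compatible with that dtype.

```python
import numpy as np

_PY_SCALARS = (bool, int, float, complex)

def result_type_with_scalars(xp, *args):
    """Hypothetical helper: result_type that tolerates Python scalars."""
    # Exact type check so NumPy scalars (e.g. np.float64, a float subclass)
    # are treated as array-like inputs, not as Python scalars.
    non_scalars = [a for a in args if type(a) not in _PY_SCALARS]
    if not non_scalars:
        # Only Python scalars were given: no dtype to anchor the result on.
        raise TypeError("need at least one array or dtype object")
    # Assumption: compatible Python scalars do not affect the result, so
    # only arrays and dtype objects take part in promotion.
    return xp.result_type(*non_scalars)

# Illustrative usage with NumPy as the namespace.
x = np.arange(3, dtype=np.int16)
dtype = result_type_with_scalars(np, 1, x, 2)
```

With only Python scalars, e.g. result_type_with_scalars(np, 1, 2), the helper raises TypeError instead of falling back to a library-specific default dtype.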

rgommers commented 1 month ago

Making this change to result_type seemed fair to everyone in the discussion we just had. Given that our type promotion rules include Python scalars, the function that can be used to apply those promotion rules should support them as well.