data-apis / array-api

RFC document, tooling and other content related to the array API standard
https://data-apis.github.io/array-api/latest/
MIT License

Python scalars in elementwise functions #807

Open shoyer opened 1 month ago

shoyer commented 1 month ago

The array API supports Python scalars in arithmetic only, i.e., operations like x + 1.

For the same readability reasons that supporting scalars in arithmetic is valuable, it would be nice to also support Python scalars in other elementwise functions, at least those that take multiple arguments, like maximum(x, 0) or where(y, x, 0).

rgommers commented 1 month ago

Hmm, I am in two minds about reconsidering this choice.

On the con side: non-array input to functions is going against the design we have had from the start, it makes static typing a bit harder (we'd need both an Array protocol and an ArrayOrScalar union), and not all libraries support it yet - PyTorch in particular. E.g.:

>>> import torch
>>> t = torch.ones(3)
>>> torch.maximum(t, 1.5)
...
TypeError: maximum(): argument 'other' (position 2) must be Tensor, not float

In principle, it looks like PyTorch is fine with adding this, but it's a nontrivial amount of work and no one is working on it as far as I know: https://github.com/pytorch/pytorch/issues/110636. PyTorch does support it in functions matching operators (e.g., torch.add) and in torch.where.

TensorFlow also doesn't support it (except in their experimental.numpy namespace, IIRC), but that's less relevant now since it doesn't look like they're going to implement array API support.

For the same readability reasons that supporting scalars in arithmetic is valuable

The readability argument is less compelling for functions than for operators, though, for two reasons. First, x + 1 is very short, so the relative increase in characters from wrapping the scalar is larger than for a function call (where modname.funcname is already long). Second, scalars are less commonly used in function calls.


On the pro side: I agree that it is pretty annoying to get right in a completely portable and generic way. In the cases where one does need it, the natural choice of asarray(scalar) often doesn't work; it should also use the dtype and device of the array. So xp.maximum(x, 1) becomes:

xp.maximum(x, xp.asarray(1, dtype=x.dtype, device=x.device))

Hence if this is a pattern that a project happens to need a lot, it will probably create a utility function like:

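from array_api_compat import array_namespace

def as_zerodim(value, x, /, xp=None):
    # Convert a Python scalar to a 0-D array with the same dtype and device as x.
    if xp is None:
        xp = array_namespace(x)
    return xp.asarray(value, dtype=x.dtype, device=x.device)

# Usage (x is any supported array):
xp = array_namespace(x)
xp.maximum(x, as_zerodim(1, x, xp=xp))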

PyTorch support comes through array-api-compat at this point, so wrapping the PyTorch functions isn't too hard. So it is doable. I think I'm +0.5 on balance. It's not the highest-prio item, but it's nice to have if it works for all implementing libraries.

asmeurer commented 1 month ago

We could support them in a bespoke way for specific arguments of useful functions, like where. We already added scalar support specifically to the min and max arguments of clip: https://data-apis.org/array-api/latest/API_specification/generated/array_api.clip.html
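For comparison, clip already accepts Python scalars for its bounds, so a one-sided lower bound can be written without any conversion. A small NumPy sketch (assuming NumPy >= 2.0 for the promotion behavior; the bounds are passed positionally):

```python
import numpy as np  # assumes NumPy >= 2.0

x = np.asarray([-1.5, 0.5, 2.0], dtype=np.float32)

# Scalar bounds are fine for clip; None leaves the upper side unbounded.
lower = np.clip(x, 0.0, None)

# Same values as maximum(x, 0.0), and the dtype stays float32.
print(lower.tolist(), lower.dtype)
```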

shoyer commented 1 month ago

On the pro side: I agree that it is pretty annoying to get right in a completely portable and generic way. In the cases where one does need it, the natural choice of asarray(scalar) often doesn't work; it should also use the dtype and device of the array. So xp.maximum(x, 1) becomes:

xp.maximum(x, xp.asarray(1, dtype=x.dtype, device=x.device))

It's even a little messier in the case Xarray is currently facing:

  1. We want this to work in a completely portable and generic way, with the minimum array-API requirements.
  2. We also still want to allow libraries like NumPy to figure out the result type themselves. For example, consider maximum(x, 0.5) where x has an integer dtype. In the array API, mixed dtype casting is undefined, but in most array libraries the result would be upcast to some form of float.
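The second point can be seen directly in NumPy (assuming NumPy >= 2.0). Passing the Python scalar lets the library choose the result type, whereas converting the scalar to x.dtype first, as a dtype-matching helper would, silently truncates it:

```python
import numpy as np  # assumes NumPy >= 2.0

x = np.arange(3)  # default integer dtype

# NumPy picks the result type itself: an integer array with a Python float
# is promoted to the default floating-point dtype.
native = np.maximum(x, 0.5)

# Converting the scalar to x.dtype first truncates 0.5 to 0 before the call.
forced = np.maximum(x, np.asarray(0.5, dtype=x.dtype))

print(native.tolist())  # [0.5, 1.0, 2.0]
print(forced.tolist())  # [0, 1, 2]
```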

asmeurer commented 1 month ago

In the array API, mixed dtype casting is undefined, but in most array libraries the result would be upcast to some form of float.

That deviates even from the operator behavior in the array API. The specified scalar OP array behavior is to only upcast the scalar to the type of the array, not the other way around: https://data-apis.org/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars. In other words, int OP float_array is OK, but float OP int_array is not. Implicitly casting an integer array to a floating-point dtype is cross-kind casting, which we've tried to explicitly avoid. (To be clear, these are all recommendations, not requirements. Libraries like NumPy are free to implement this if they choose to.)

Similarly, clip, which as I mentioned is an example of a function that already allows Python scalars, leaves mixed-kind scalars unspecified, although I personally think it should adopt the same logic as operators and allow int bounds alongside floating-point arrays.

asmeurer commented 3 weeks ago

List of potential APIs to support Python scalars in:

I've omitted all 1-argument elementwise functions.

I've also omitted these multi-argument elementwise functions, which seem less useful at a glance, but let me know if adding scalar support to any of them would be useful:

By the way, repeat() is also prior art for allowing either an array or an int for the repeats argument.
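As a small illustration of that prior art, NumPy's repeat accepts either form of repeats:

```python
import numpy as np

x = np.asarray([1, 2, 3])

# repeats as a Python int: every element is repeated the same number of times.
same = np.repeat(x, 2)

# repeats as an array: one count per element.
per_element = np.repeat(x, np.asarray([1, 0, 2]))

print(same.tolist())         # [1, 1, 2, 2, 3, 3]
print(per_element.tolist())  # [1, 3, 3]
```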

rgommers commented 2 weeks ago

Thanks Aaron. I agree with the choices here: where, clip, copysign, nextafter, minimum and maximum are quite naturally written with one scalar input, and it's easy to find real-world examples of that.

maximum (this function is symmetrical, so it's not clear if support should be added to both arguments or just argument 2)

Since there must always be at least one input that is an array, it would be more annoying to statically type if both arguments may be scalars. It then requires overloads; just an array | scalar union for both arguments isn't enough. So it seems preferable to me to not make it symmetric.

asmeurer commented 2 weeks ago

Do you agree that allowing scalars for both arguments of nextafter is useful? At a glance, it seems to me that one could end up using it for a scalar in either argument, but I also haven't made much use of the function myself.

If so, should we allow scalars in both arguments? It could be useful, for instance, to just compute a specific float x + eps. The result should be a 0-D array. Also, unlike other cases, one wouldn't necessarily want to use math.nextafter since if your default floating-point dtype is float32, you would want x + float32_eps. OTOH, this would necessarily just automatically use the default floating-point dtype and the default device for the result. One can always manually cast one of the arguments to a 0-D array, though I don't know if that's an argument to allow or to not allow it.
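To illustrate the float32_eps point: math.nextafter always steps in double precision, while an array library's nextafter on float32 inputs steps by the much larger float32 spacing. A NumPy sketch (assumes Python >= 3.9 for math.nextafter):

```python
import math

import numpy as np  # assumes NumPy >= 2.0

x = 1.0

# math.nextafter always works in float64 (Python float) precision.
step64 = math.nextafter(x, math.inf) - x

# With float32 0-D arrays, nextafter steps by the float32 spacing instead.
x32 = np.asarray(x, dtype=np.float32)
step32 = np.nextafter(x32, np.asarray(math.inf, dtype=np.float32)) - x32

print(step64 == 2.0**-52, float(step32) == 2.0**-23)  # True True
```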

mdhaber commented 2 weeks ago

Just wanted to add a +1 to this effort; it would really simplify translation efforts. If you need additional opinions about fine points, LMK.

shoyer commented 5 days ago

One consideration that came up in discussion: How can users write new elementwise functions that support scalars in some arguments themselves using the array API?

e.g., suppose we want polar_to_cartesian() to work with either r or theta being a scalar:

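def polar_to_cartesian(r, theta):
    # r and/or theta may be a Python scalar; get_array_namespace stands in for
    # a namespace lookup such as array_api_compat.array_namespace
    xp = get_array_namespace(r, theta)
    return (r * xp.cos(theta), r * xp.sin(theta))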

This seems to require supporting scalars even in single-argument elementwise operations like sin and cos.

kgryte commented 5 days ago

@shoyer Your example may run into issues due to device selection. What device should libraries supporting multiple devices (e.g., PyTorch) allocate the arrays returned from xp.sin(theta) and xp.cos(theta) on? And what if r is on a different device? That seems like it could be a recipe for issues.

seberg commented 5 days ago

I find it a bit weird to not allow the scalar in both (it would be nice if you could just add a (scalar, scalar) -> error overload, but dunno if that is possible, and it probably doesn't matter much in practice).

Only allowing one of the two is a bit strange for asymmetric functions. nextafter is probably a bit niche, copysign maybe too, and atan2 and hypot may not matter in a first iteration. So I doubt it is a true long-term solution, but OK.

asmeurer commented 5 days ago

nextafter is used quite a bit in SciPy https://github.com/search?q=repo%3Ascipy%2Fscipy%20nextafter&type=code (although half the uses are in the tests), with things like nextafter(np.pi, np.inf), suggesting that calls with both arguments scalar are common.

Ditto for copysign https://github.com/search?q=repo%3Ascipy%2Fscipy+copysign&type=code. In fact, I would say my suggestion above to only support argument 2 in copysign was wrong. If anything, it's more common to use copysign(0.0, x) to create a signed 0.

Based on the discussions in the meeting today, we should still require at least one argument to be an array for now (or rather, leave full scalar inputs unspecified), but I don't see any reason to prefer arguments 1 or 2 for nextafter or copysign. There was also some sentiment that breaking symmetry for minimum and maximum would be confusing, so we should perhaps allow scalars in either argument for those functions.

Not including scalar support for argument 1 in where, clip, or repeat should be fine I'd imagine, but if anyone is aware of use-cases otherwise please point them out.

rgommers commented 5 days ago

Based on the discussions in the meeting today,

The brief summary of that was:

asmeurer commented 5 days ago

There are also some functions I intentionally omitted from my list above because it doesn't really seem to make sense for them to accept scalars, even though they take multiple arrays, since they are array manipulation functions, like broadcast_arrays, stack, and concat. There are also functions that don't allow 0-D inputs at all, like searchsorted and all the linalg functions, which should obviously be omitted from this list.

shoyer commented 5 days ago

I can maybe see a case for supporting scalars in broadcast_arrays and stack (NumPy and JAX support it), though it's pretty marginal.

Scalars don't make sense in concat because it requires arrays with at least one dimension.
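The difference can be checked with NumPy: broadcast_arrays and stack convert scalar inputs to 0-D arrays, while concatenation rejects them:

```python
import numpy as np

# NumPy converts scalar inputs to 0-D arrays in broadcast_arrays and stack.
a, b = np.broadcast_arrays(np.ones(3), 5)
stacked = np.stack([1, 2])

print(b.tolist())        # [5, 5, 5]
print(stacked.tolist())  # [1, 2]

# concatenate needs inputs with at least one dimension, so scalars fail.
try:
    np.concatenate([1, 2])
except ValueError as exc:
    print("ValueError:", exc)
```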