data-apis / array-api

RFC document, tooling and other content related to the array API standard
https://data-apis.github.io/array-api/latest/
MIT License

Reconsider sum/prod/trace upcasting for floating-point dtypes #731

rgommers closed this 4 months ago

rgommers commented 5 months ago

The sum spec currently requires that sum(x) be upcast to the default floating-point dtype when called with the default dtype=None:

If x has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.

The rationale given is "keyword argument is intended to help prevent data type overflows." This came up again in the review of NEP 56 (https://github.com/numpy/numpy/pull/25542), and is basically the only part of the standard that was flagged as problematic and explicitly rejected.

I agree that the standard's choice here is problematic, at least from a practical perspective: no array library does this, and none are planning to implement it. The rationale is also pretty weak: it does not apply to floating-point dtypes to nearly the same extent as it does to integer dtypes (and for integers, array libraries do implement the upcasting). Examples:

>>> # NumPy:
>>> import numpy as np
>>> np.sum(np.ones(3, dtype=np.float32)).dtype
dtype('float32')
>>> np.sum(np.ones(3, dtype=np.int32)).dtype
dtype('int64')

>>> # PyTorch:
>>> import torch
>>> torch.sum(torch.ones(2, dtype=torch.bfloat16)).dtype
torch.bfloat16
>>> torch.sum(torch.ones(2, dtype=torch.int16)).dtype
torch.int64

>>> # JAX:
>>> import jax.numpy as jnp
>>> jnp.sum(jnp.ones(4, dtype=jnp.float16)).dtype
dtype('float16')
>>> jnp.sum(jnp.ones(4, dtype=jnp.int16)).dtype
dtype('int32')

>>> # CuPy:
>>> import cupy as cp
>>> cp.sum(cp.ones(5, dtype=cp.float16)).dtype
dtype('float16')
>>> cp.sum(cp.ones(5, dtype=cp.int32)).dtype
dtype('int64')

>>> # Dask:
>>> import dask.array as da
>>> da.sum(da.ones(6, dtype=np.float32)).dtype
dtype('float32')
>>> da.sum(da.ones(6, dtype=np.int32)).dtype
dtype('int64')

The most relevant conversation is https://github.com/data-apis/array-api/pull/238#issuecomment-922130293. There were some further minor tweaks (without much discussion) in gh-666.

Proposed resolution: align the standard with what all known array libraries implement today.

seberg commented 5 months ago

As I mentioned a few times before, I agree with not specifying this. Partly, because I think it is just asking too much from NumPy (and apparently others). But even from an Array API perspective I don't think it is helpful, because "default type" is effectively just "unspecified" too (if you sum a float32 array, you don't know whether you get a float32 or a float64).
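If code needs a particular result precision, the portable option is to request it explicitly through the dtype keyword rather than rely on the default. A minimal sketch, using NumPy only as an example library:

```python
import numpy as np

x = np.ones(3, dtype=np.float32)

# Default dtype=None: NumPy keeps float32 for a float32 input.
default_result = np.sum(x)

# Explicit opt-in: the caller asks for a float64 result.
explicit_result = np.sum(x, dtype=np.float64)

print(default_result.dtype, explicit_result.dtype)  # float32 float64
```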

rgommers commented 5 months ago

Partly, because I think it is just asking too much from NumPy (and apparently others).

Since all libraries appear to do exactly the same thing as of today, what's the problem with encoding that? Nothing is being asked from anyone at that point; it's basically just documenting the status quo.

seberg commented 5 months ago

I wouldn't have been surprised if someone upcast for float16, but if not, then fine. Mainly, I am not sure I would mind the old proposal if it were coming from scratch, so I don't have an opinion about allowing it (i.e. not caring that the result may have higher precision).

mhvk commented 5 months ago

My 2¢ is that it is good to codify the current behaviour for the various float dtypes. It is really surprising if the dtype of a reduction depends on the operation.

p.s. Indeed, I think this is true even for integers. At least, to me, the following is neither logical nor expected:

In [17]: np.subtract.reduce(np.arange(4, dtype='i2')).dtype
Out[17]: dtype('int16')

In [18]: np.add.reduce(np.arange(4, dtype='i2')).dtype
Out[18]: dtype('int64')

Explicit is better than implicit and all that. And for reductions, it might be quite reasonable to do the operation at higher precision and check for overflow before downcasting.
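One way to read "do the operation at higher precision and check for overflow before downcasting" is sketched below. checked_sum is a hypothetical helper, not a NumPy function: it accumulates with Python's arbitrary-precision ints and raises instead of wrapping. The result dtype stays fixed, so this is not a value-based result type.

```python
import numpy as np

def checked_sum(x):
    """Hypothetical helper: sum an integer array without changing its dtype.

    Accumulates with Python ints (arbitrary precision) and raises
    OverflowError instead of silently wrapping around.
    """
    if x.dtype.kind not in "iu":
        raise TypeError("checked_sum only handles integer dtypes")
    total = sum(int(v) for v in x.ravel())
    info = np.iinfo(x.dtype)
    if not info.min <= total <= info.max:
        raise OverflowError(f"sum {total} does not fit in {x.dtype}")
    return x.dtype.type(total)  # same dtype as the input


print(checked_sum(np.arange(4, dtype=np.int16)))  # 6, as an int16 scalar
```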

asmeurer commented 5 months ago

And for reductions, it might be quite reasonable to do the operation at higher precision and check for overflow before downcasting.

If you're suggesting a value-based result type, that's even worse. That's the sort of thing we're trying to get away from with the standard.

asmeurer commented 5 months ago

The rationale given is "keyword argument is intended to help prevent data type overflows." This came up again in the review of NEP 56 (https://github.com/numpy/numpy/pull/25542), and is basically the only part of the standard that was flagged as problematic and explicitly rejected.

That PR discussion is huge and you didn't point to a specific comment, so I don't know what was already said. But it makes sense to me to treat floats differently from ints because floats give inf when they overflow, which is a very clear indication to the user that they need to manually upcast.
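The float-versus-int contrast can be seen directly in NumPy. In this sketch the values are chosen so the true sum exceeds each dtype's range: the float16 sum becomes inf, while an int16 sum forced to stay int16 wraps around.

```python
import numpy as np

# float16 tops out at 65504, so 60000 + 60000 cannot be represented.
x_f16 = np.full(2, 60000, dtype=np.float16)
with np.errstate(over="ignore"):  # silence the overflow RuntimeWarning for the demo
    s_f16 = np.sum(x_f16)
print(s_f16, s_f16.dtype)  # inf float16: the overflow is visible

# int16 tops out at 32767; keeping the int16 dtype makes the sum wrap.
x_i16 = np.full(2, 30000, dtype=np.int16)
s_i16 = np.sum(x_i16, dtype=np.int16)
print(s_i16, s_i16.dtype)  # -5536 int16: wrapped, no inf to signal it
```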

rgommers commented 5 months ago

That PR discussion is huge and you didn't point to a specific comment, so I don't know what was already said

There are several comments on it. The main one is this comment. Then it also got mixed in with the comment on in-place operator behavior in this comment. And in this comment @seberg said "(I explicitly did not give a thumbs-up for the type promotion changes in that meeting)" (type promotion meaning the sum/prod ones).

I did write it down as one requirement among many (I didn't quite agree with what I wrote myself, but forgot to revisit it); it didn't stand out in the text. It's telling that it was flagged quickly by both @seberg and @mhvk as too impactful to change.

And for reductions, it might be quite reasonable to do the operation at higher precision and check for overflow before downcasting.

If you're suggesting a value-based result type, that's even worse. That's the sort of thing we're trying to get away from with the standard.

Internal upcasting is regularly done, and perfectly fine. I assume the intent was "warn or raise on integer overflow", rather than value-based casting.

mhvk commented 5 months ago

If you're suggesting a value-based result type, that's even worse. That's the sort of thing we're trying to get away from with the standard.

No, not a different type; that would be awful indeed! But an over/underflow error/warning, just like we can currently get for floating-point operations. For regular ufuncs, that is too much of a performance hit, but for reductions, it should not be. And reductions are a bit special already, since it definitely makes sense to do things at higher precision internally before casting back to the original precision.
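The "higher precision internally, original precision out" pattern can be sketched in NumPy as follows. sum_keep_dtype is a hypothetical name, and using float64 as the accumulator is an assumption made for illustration.

```python
import numpy as np

def sum_keep_dtype(x):
    """Hypothetical helper: accumulate a float array in float64,
    then cast the result back to the input's dtype."""
    acc = np.sum(x, dtype=np.float64)  # higher-precision accumulation
    return acc.astype(x.dtype)         # result dtype matches the input


x = np.full(100_000, 0.1, dtype=np.float32)
result = sum_keep_dtype(x)
print(result, result.dtype)  # a float32 result close to 10000
```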

asmeurer commented 5 months ago

p.s. Indeed, I think this is true even for integers. At least, to me, the following is neither logical nor expected:

ufunc.reduce is not part of the standard, so it's not really relevant here, but FWIW, I agree with you that it's quite surprising for ufunc.reduce to not return the same dtype as the ufunc itself in some cases. I think of ufunc methods as being somewhat "low-level" things that shouldn't try to be overly smart, at least in terms of behavior (ufunc-specific performance optimizations are another thing).

OTOH, sum is a distinct function from add and is more of an end-user function, so I don't know if the argument applies there.

leofang commented 5 months ago

Proposed resolution: align the standard with what all known array libraries implement today.

@rgommers What new wording would you change it to?

rgommers commented 5 months ago

The current wording that is problematic is:

I suggest a change like this:

This loosens the spec, recommends the current behavior of all known libraries, and still allows upcasting if an implementation desires to do so for (reasons).

seberg commented 5 months ago

the returned array should have either the same dtype as x (recommended) or a higher-precision dtype of the same kind as the dtype of x

Thanks, looks good to me. Maybe it would be slightly clearer to replace the "or ..." with a new sentence: "If it is not the same dtype, it must be a higher-precision ..."? (Because if you apply the "should" also to the "or ..." part, it would be a "must".)

EDIT: Or just replace the "should" with a "must". To me it seems to apply to the full construct, so "must" is correct, and the "(recommended)" already covers the "should" part.

asmeurer commented 5 months ago

I personally don't see value in hedging with "recommended" or "should" if no one actually does that now and we don't even have a concrete reason for anyone to do so. It feels like our only real rationale is some misunderstanding in the original discussion about int dtypes. Not being precise about dtypes has disadvantages. For instance, it makes it harder to reason about dtypes statically (https://github.com/data-apis/array-api/issues/728). Everywhere else in the standard uses "must" for output dtypes (correct me if I'm wrong on that).

And I also disagree that upcasting is not a big deal. When you're explicitly using a lower-precision float, silent or unexpected upcasting can have a very real performance impact. Here's an example where fixing an unexpected float64 upcast made some code 5x faster: https://github.com/jaymody/picoGPT/pull/12.

seberg commented 5 months ago

I am fine with being strict here and saying it must be the same: it is the only version that I see giving any clarity to libraries supporting multiple implementations (which is my main emphasis here always, compared to thinking about the ideal end-user API).

But there must have been some feeling that float16 and float32 lose precision quickly and that users need protecting, which is why this ended up written down. And I am happy to accept the opinion that it may be a reasonable choice for end-users. Although, I can see the argument that this is really about the intermediate precision of the reduction, an argument that could even be made for integers: so long as you detect the overflow (by aggregating at high/arbitrary precision), forcing users to upcast manually isn't that terrible.

treat floats differently from ints because floats give inf when they overflow, which is a very clear indication to the user that they need to manually upcast.

N.B.: To clarify, for summation, overflows are actually not the main issue! The issue is extreme loss of precision unless you have a high-precision intermediate (at least float64). If you sum float32(1.) naively, the result just caps at 2**24 == 16777216, and you may see a decent amount of loss earlier. A million elements are not that odd in many contexts.
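The stall is easy to check with float32 scalars. float32 has a 24-bit significand, so every integer up to 2**24 is exact, but once a naive running sum reaches 2**24, adding 1.0 rounds back to the same value. A small NumPy sketch:

```python
import numpy as np

one = np.float32(1.0)

below = np.float32(2**23)
print(below + one == np.float32(2**23 + 1))  # True: still exact at 2**23

cap = np.float32(2**24)  # 16777216
print(cap + one == cap)  # True: 2**24 + 1 rounds back to 2**24

# A naive sequential float32 accumulator therefore stops growing at 2**24.
acc = cap
for _ in range(1000):
    acc = acc + one
print(acc)  # 16777216.0
```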

rgommers commented 4 months ago

Okay, seems like there is support for "must", and I agree that that is nicer. PR for that: gh-744.
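Under the "must" wording, a floating-point input to sum or prod must produce a result with the same dtype. A rough sketch of a check: float_dtype_mismatches is a hypothetical helper that takes an array namespace xp (NumPy here) and reports every case where the input dtype is not preserved.

```python
import numpy as np

def float_dtype_mismatches(xp, dtypes, funcs=("sum", "prod")):
    """Hypothetical helper: return (func name, input dtype, output dtype)
    for every reduction in funcs that does not preserve the input dtype."""
    mismatches = []
    for name in funcs:
        func = getattr(xp, name)
        for dt in dtypes:
            x = xp.ones(3, dtype=dt)
            out = func(x)
            if out.dtype != x.dtype:
                mismatches.append((name, x.dtype, out.dtype))
    return mismatches


# NumPy preserves these float dtypes, so the list is expected to be empty.
print(float_dtype_mismatches(np, [np.float16, np.float32, np.float64]))
```

Passing an integer dtype such as int16 would show up as mismatches, since integer sums and products are upcast.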

asmeurer commented 4 months ago

N.B.: To clarify, for summation, overflows are actually not the main issue! The issue is extreme loss of precision unless you have a high-precision intermediate (at least float64). If you sum float32(1.) naively, the result just caps at 2**24 == 16777216, and you may see a decent amount of loss earlier. A million elements are not that odd in many contexts.

This can be solved by using a higher intermediate precision, or by using a smarter summation algorithm. My point is that the only reason you'd need a higher result precision is if there is an overflow.
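As a sketch of the "smarter summation algorithm" option, here is Kahan (compensated) summation done entirely in float32, compared with a naive float32 loop. Both function names are hypothetical, and NumPy scalars are used only to force float32 arithmetic.

```python
import numpy as np

def naive_sum_f32(values):
    total = np.float32(0.0)
    for v in values:
        total = total + v  # rounding error accumulates with every addition
    return total

def kahan_sum_f32(values):
    total = np.float32(0.0)
    comp = np.float32(0.0)  # running compensation for lost low-order bits
    for v in values:
        y = v - comp            # apply the correction from the previous step
        t = total + y           # low-order bits of y may be lost here...
        comp = (t - total) - y  # ...and are recovered algebraically
        total = t
    return total


x = np.full(100_000, 0.1, dtype=np.float32)
print(naive_sum_f32(x), kahan_sum_f32(x))  # naive drifts; Kahan stays near 10000
```

The result stays float32 in both cases; only the algorithm changes, which is the distinction between intermediate precision and result precision drawn above.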