I feel part of this issue is related to the fact that different implementations handle the intermediate type during reduction differently. On the C++ side this is also a known problem, and there is an ongoing proposal to fix this ambiguity (and all the associated issues) by requiring the intermediate type to be either the dtype of the initial value (if there is one) or the input iterator's value type (https://wg21.link/P0571). I would suggest following the C++ behavior if possible (assuming the proposal is accepted), as most libraries have a C++ implementation under the hood. The net effect is likely leaning toward keeping the input dtype.
Pinging the proposal author @brycelelbach for awareness (maybe he could comment on potential pitfalls or challenges).
On the CuPy side, I see no reason not to follow this behavior, except that we want to be NumPy compliant in the main namespace.
Interesting, thanks @leofang. That proposal goes into a lot of depth about the intermediate type, which matters for implementers but is something we should be agnostic about here imho. The only thing that matters is the output dtype, because that's user-observable behavior. If someone wants to write an implementation where the intermediate type is `long double` for all input dtypes, and then cast the final result back to the correct output dtype, that should be perfectly fine.
I'm very unsure about what the output dtype should be. Preserving the input dtype sounds nice in theory, but I expect the TF/MXNet behavior to be a foot gun in practice. Take for example `uint8` and `uint16`: the main use case for those is image processing. Any practically relevant `sum(uint8_image)` is going to overflow, and many `uint16` images will too. So it seems like a very impractical choice.
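For instance, in NumPy one can see what dtype-preserving accumulation would do by forcing the accumulator dtype:

```python
>>> import numpy as np
>>> image = np.full((100, 100), 200, dtype=np.uint8)
>>> image.sum()                # NumPy upcasts, so this is fine
2000000
>>> image.sum(dtype=np.uint8)  # uint8 accumulation wraps modulo 256
128
```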
Reductions in NumPy have an `initial` keyword (see, e.g., the docs for `sum`), which seems necessary to force an upcast in case the behavior were "keep input dtype", but it is hardly used in NumPy code because the default behavior is fine. TF's `reduce_sum` doesn't have this. A quick search didn't turn up an issue about this on the TF issue tracker.
Side note: NumPy does have something weird going on as well; it's happy to use scalar negative values in unsigned integer reductions, and it's unclear to me why:
```python
>>> import numpy as np
>>> image = np.ones((100, 2), dtype=np.uint8)
>>> np.sum(image, axis=0)
array([100, 100], dtype=uint64)
>>> np.sum(image, axis=0, initial=1)
array([101, 101], dtype=uint64)
>>> np.sum(image, axis=0, initial=-1)
array([99, 99], dtype=uint64)
>>> np.sum(image, axis=0, initial=np.array(-1))
...
TypeError: Cannot cast scalar from dtype('int64') to dtype('uint64') according to the rule 'safe'
```
Searching for "`reduce_sum` overflow" is hard, because of "Stack Overflow". But this shows exactly what I thought TF would struggle with: https://github.com/tensorflow/tensorflow/commit/23fde233bf3210759b5a4453bc39101df9c86d0c, https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/reduction_ops.h#L60.
```cpp
// Specialization for which we do the reduction in IntermediateType to
// avoid integer overflow.
#define CASTING_SPECIALIZATION(ScalarType, IntermediateType)
...
CASTING_SPECIALIZATION(uint8, uint64);
CASTING_SPECIALIZATION(uint16, uint64);
<etc>
```
And the same for floating point:

```cpp
// Specialization for BF16 Reducer to fix accuracy.
// TODO: All BF16 reducers should have specializations to fix accuracy.
```
JAX discussion: https://github.com/google/jax/issues/3154. That leans towards making a change from its current behavior to preserving input dtype.
The discussion there does make sense for, e.g., `float32`. The trouble is that for 32-bit dtypes, preserving the input dtype seems clearly better, but for lower-precision dtypes it will almost never be the right thing to do.
EDIT: copying the example from that issue to show JAX behavior for all dtypes:
```python
In [1]: import jax.numpy as jnp
   ...: from jax.test_util import dtypes
   ...: from jax import config; config.update('jax_enable_x64', True)
   ...: for dtype in dtypes.all:
   ...:     print(dtype.__name__, "->", jnp.zeros(2, dtype).sum().dtype)
   ...:
bfloat16 -> bfloat16
float16 -> float16
float32 -> float32
float64 -> float64
int8 -> int64
int16 -> int64
int32 -> int64
int64 -> int64
uint8 -> uint64
uint16 -> uint64
uint32 -> uint64
uint64 -> uint64
complex64 -> complex64
complex128 -> complex128
bool_ -> int64
```
Also, JAX has a global setting to switch default dtype to 64-bit??
> Preserving input dtype sounds nice in theory, but I expect the TF/MXNet behavior to be a foot gun in practice.
Personally, I think pushing considerations of overflow to userland is fine. If a user has reason to be concerned about overflow during summation or multiplication, then explicitly casting an array to a dtype capable of handling larger values without overflow should be okay. That a user would be forced to explicitly think about desired dtypes is not necessarily a bad thing, imo.
As has been discussed elsewhere, requiring explicit casting may incur costs, such as performance overhead due to memory allocation and/or multiple data passes. However, those costs are likely to be incurred regardless. While libraries such as NumPy may not be able to benefit from whole graph optimization, others may be able to combine casts/reductions into a single operation.
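In NumPy, for example, the explicit-cast pattern looks like the first line below, while the second line folds the cast into the reduction itself (avoiding the intermediate copy):

```python
import numpy as np

img = np.full((4096, 4096), 255, dtype=np.uint8)
total = img.astype(np.uint64).sum()  # explicit cast: extra memory pass
total = img.sum(dtype=np.uint64)     # accumulator dtype: no intermediate copy
```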
> but for lower-precision dtypes it will almost never be the right thing to do.
I don't agree. Naive summation techniques may overflow, especially if provided a large set of monotonically increasing positive values. However, for summands of mixed sign, various summation techniques are possible which guard against overflow by using correction terms. So I don't think overflow is guaranteed to be rampant.
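For floating point, compensated (Kahan) summation is the classic correction-term technique; it targets rounding error rather than integer overflow, but it illustrates the idea. A minimal sketch:

```python
def kahan_sum(values):
    """Compensated (Kahan) summation: `c` carries the low-order bits
    that would otherwise be lost when `total` grows large."""
    total = 0.0
    c = 0.0
    for x in values:
        y = x - c
        t = total + y
        c = (t - total) - y  # algebraically zero; captures the rounding error
        total = t
    return total
```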
Another, potentially left-field, option is to support specifying the output dtype (e.g., `output_dtype=None`). By default, the behavior could be to return the input dtype. If a user wants an alternative dtype, they can opt in and the underlying implementations can figure out the best way to meet the request.

The issue here, however, is that the specification would be underspecified for mixed-kind input/output dtypes. So one potential constraint could be to require the output dtype to be of the same kind and possibly of equal or greater size.
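A rough sketch of that constraint (the keyword `output_dtype` and this helper are hypothetical, not part of any existing API):

```python
import numpy as np

def validate_output_dtype(input_dtype, output_dtype):
    # Hypothetical helper: enforce same kind and equal-or-greater size,
    # per the constraint suggested above.
    inp, out = np.dtype(input_dtype), np.dtype(output_dtype)
    if out.kind != inp.kind:
        raise TypeError(f"output kind {out.kind!r} does not match input kind {inp.kind!r}")
    if out.itemsize < inp.itemsize:
        raise TypeError(f"output dtype {out} is smaller than input dtype {inp}")

validate_output_dtype(np.uint8, np.uint64)   # ok
# validate_output_dtype(np.uint8, np.int64)  # raises: kind 'i' != kind 'u'
```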
To be clear, what we are discussing here only applies to a subset of statistical reductions, namely `sum` and `prod`.

For `max` and `min`, returning the input dtype is perfectly acceptable.
For `mean`, `var`, and `std`, returning the default floating-point type is the primary practical option, as robust estimation algorithms must apply scaling (particularly for `var` and `std`) and thus require a dtype capable of handling decimals. As elsewhere, such as in element-wise mathematical functions, how libraries handle integer dtypes while returning a floating-point dtype (i.e., mixed-kind casting) should be implementation-defined.
Thus leaving `sum` and `prod`, which should be consistent with one another; however, if we don't default to returning the input dtype, then we'll have a third category of promotion rules for a particular subset of statistical reductions, which may or may not be desirable.
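Summarizing the categories above as a sketch (the function and its rules are hypothetical, just restating this comment, with `sum`/`prod` as the open question):

```python
def reduction_result_dtype(op, input_dtype, default_float="float64"):
    """Hypothetical summary of the dtype rule categories discussed above."""
    if op in ("max", "min"):
        return input_dtype        # preserving the input dtype is fine
    if op in ("mean", "var", "std"):
        return default_float      # needs to hold decimals
    if op in ("sum", "prod"):
        # The open question: preserve the input dtype, or upcast
        # small integers to avoid overflow?
        raise NotImplementedError
```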
The new C++ guidance is to infer the intermediate type from the operator; see P2322 (https://wg21.link/P2322).
> To be clear, what we are discussing here only applies to a subset of statistical reductions, namely `sum` and `prod`.

Yes indeed. There's yet another set, namely the reductions that return bool dtype (`any`, `all`).
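E.g., in NumPy:

```python
>>> import numpy as np
>>> np.any(np.zeros(3, dtype=np.int8)).dtype
dtype('bool')
>>> np.all(np.ones(3, dtype=np.float32)).dtype
dtype('bool')
```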
> For `mean`, `var`, and `std`, returning the default floating-point type is the primary practical option
Agreed. The specification is probably wrong though. It says to return the default floating-point dtype, but it should preserve dtypes (while accumulating in higher precision as needed):

```python
>>> import numpy as np
>>> np.std(np.ones(3, dtype=np.float32)).dtype
dtype('float32')
```
> Another, potentially left-field, option is to support specifying the output dtype (e.g., `output_dtype=None`).
Not all that left-field. From the `numpy.sum` docstring:

```
dtype : dtype, optional
    The type of the returned array and of the accumulator in which the
    elements are summed. The dtype of `a` is used by default unless `a`
    has an integer dtype of less precision than the default platform
    integer. In that case, if `a` is signed then the platform integer
    is used while if `a` is unsigned then an unsigned integer of the
    same precision as the platform integer is used.
```
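On a platform with a 64-bit default integer, that behavior looks like:

```python
>>> import numpy as np
>>> np.sum(np.ones(3, dtype=np.int8)).dtype
dtype('int64')
>>> np.sum(np.ones(3, dtype=np.uint8)).dtype
dtype('uint64')
>>> np.sum(np.ones(3, dtype=np.int8), dtype=np.int8).dtype
dtype('int8')
```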
> The specification is probably wrong though.
You are right. At a minimum, we should clarify that returning the default floating-point dtype applies when integer dtypes are provided.
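I.e., in NumPy:

```python
>>> import numpy as np
>>> np.std(np.ones(3, dtype=np.int8)).dtype     # integer input: default float
dtype('float64')
>>> np.std(np.ones(3, dtype=np.float32)).dtype  # float input: preserved
dtype('float32')
```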
For dealing with int8, the main use case in deep learning is probably quantization for inference. Intel has a comprehensive library for such usage: https://github.com/intel/lpot
I don't have a formed opinion about special-casing integers. I like the consistency of not "up-promoting", but `sum` and `prod` may just be special and common enough (it is pretty awkward in NumPy currently).
A few (NumPy specific) points, most of which just to read and forget :):
- `initial` should be used, but NumPy currently does not, i.e. `reduce_dtype = (initial.dtype + arr.dtype).dtype` (with the potential special-case "up-promotion" after that(?)). (The current behaviour is to ignore `initial.dtype` in NumPy. This only works out in the uint64 example above because `(uint64)-1` happens to do the right thing with integer rollover/overflow.)
- `std`, etc. should promote using `common_dtype(arr.dtype, Floating)` (this is currently spelled as `np.result_type(arr, 0.)` usually; see the snippet after this list). This is a fairly common pattern and not super specific to reduce-like operations. (Something like `Real` instead of `Floating` could be more precise, but it means the same thing here.)
- `std` should be the same as `mean`. And for `mean` it may be nice to ensure alignment with `arr.sum() / arr.size`.
- `np.sum()` and `np.add.reduce()` could in theory differ, although I am not sure it would be easy to do. (That is to say, `sum` and `prod` can be a bit special.)
- `sum(list(arr))` or `reduce(np.add, *arr)` does not give comparable results to `np.sum(arr, axis=0)` since we up-promote/cast.
- `sum(uint16_arr)` could also be `uint64` (why signed?).
- NumPy could give `IntegerOverflowWarning` warnings (and make that the default).
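The `np.result_type(arr, 0.)` spelling from the second bullet, for reference:

```python
>>> import numpy as np
>>> np.result_type(np.ones(3, dtype=np.int64), 0.)
dtype('float64')
>>> np.result_type(np.ones(3, dtype=np.float32), 0.)
dtype('float32')
```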
Just a little context from the PyTorch side:

- PyTorch has a `dtype` argument on reductions to control their output explicitly, but if performing reductions over a variety of dtypes (like when writing a library function) it might be hard to use this properly if the goal is to upcast some dtypes.
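For example (recent PyTorch):

```python
import torch

x = torch.ones(3, dtype=torch.uint8)
print(torch.sum(x).dtype)                       # torch.int64: integer inputs are upcast
print(torch.sum(x, dtype=torch.float32).dtype)  # torch.float32: explicit override
```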
This is also relevant to the `trace` function in the linear algebra extension (there are other functions in linear algebra that accept integer inputs, but they all take two arguments, so they use normal type promotion).
It seems `trace` was not updated. We should make it consistent with `sum`. Currently `np.trace` does the same type promotion as `np.sum`.
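E.g.:

```python
>>> import numpy as np
>>> np.trace(np.eye(3, dtype=np.int8)).dtype
dtype('int64')
>>> np.sum(np.eye(3, dtype=np.int8)).dtype
dtype('int64')
```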
@asmeurer are there plans to update `trace` sometime soon? We are encountering dtype issues in ivy testing.
I opened a new issue to track this: https://github.com/data-apis/array-api/issues/493. I would just make the change, but I'm not sure if we should also add a `dtype` argument to `trace()` to match `sum()` and `prod()`.
Reductions were added in PR gh-17, based on discussion in gh-10. There was quite a bit of discussion in calls as well around reductions (e.g., which ones to support, returning 0-D arrays and not scalars, naming), but not about casting rules and accepted input dtypes. It turns out that this is pretty inconsistent between libraries; a script comparing `sum`, `std` and `prod` across them, and its output, led to the conclusions below.
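The original script and its output aren't reproduced here; a minimal sketch of such a comparison, for NumPy alone (the original exercised several array libraries the same way), might look like:

```python
import numpy as np

# Compare the result dtype of sum/std/prod across input dtypes.
dtypes = [np.int8, np.uint8, np.int64, np.float16, np.float32, np.float64]
for func in (np.sum, np.std, np.prod):
    for dt in dtypes:
        x = np.ones(4, dtype=dt)
        print(f"{func.__name__}({np.dtype(dt).name}) -> {func(x).dtype}")
```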
Conclusions

For `sum(int8)` and `prod(int8)` there appear to be two options: preserve the input dtype (TensorFlow, MXNet) or upcast to the default integer dtype (NumPy and most other libraries).

The TensorFlow docs do note this as the one inconsistency with NumPy: https://www.tensorflow.org/api_docs/python/tf/math/reduce_sum says "Equivalent to np.sum apart the fact that numpy upcast uint8 and int32 to int64 while tensorflow returns the same dtype as the input."
The MXNet docs at https://mxnet.apache.org/versions/master/api/python/docs/api/np/generated/mxnet.np.sum.html#mxnet-np-sum do not clearly say that this is expected, even though those docs do have a list of differences with NumPy (@szha thoughts on this?).
For `std(int8)` there appear to be three options.

This is all quite inconsistent, and needs to be considered more carefully for all reductions and dtypes.