data-apis / array-api

RFC document, tooling and other content related to the array API standard
https://data-apis.github.io/array-api/latest/
MIT License

Specify casting rules and accepted input dtypes for reductions better #202

Closed: rgommers closed this issue 3 years ago

rgommers commented 3 years ago

Reductions were added in PR gh-17, based on discussion in gh-10. There was quite a bit of discussion in calls as well around reductions (e.g., which ones to support, returning 0-D arrays and not scalars, naming) but not about casting rules and accepted input dtypes. It turns out that this is pretty inconsistent between libraries. Here's a script that compares sum, std and prod:

import numpy as np
import dask.array as da
import torch
import tensorflow as tf
import jax.numpy as jnp
import mxnet
try:
    import cupy as cp
except ImportError:
    # CuPy is GPU-only, so may not be available
    cp = None

def ones(mod, shape):
    # Create an int8 array of ones with the given shape
    if mod in (da, mxnet.np):
        # Use NumPy's dtype object here (MXNet doesn't have dtype literals)
        x = mod.ones(shape, dtype=np.int8)
    else:
        x = mod.ones(shape, dtype=mod.int8)

    return x

def sum(mod, x):
    if mod == tf:
        y = tf.math.reduce_sum(x)
    else:
        y = mod.sum(x)

    return y

def std(mod, x):
    if mod == tf:
        y = tf.math.reduce_std(x)
    else:
        # NumPy, PyTorch, mxnet.np, Dask, JAX and CuPy all spell it `std`
        y = mod.std(x)

    return y

def prod(mod, x):
    if mod == tf:
        y = tf.math.reduce_prod(x)
    else:
        y = mod.prod(x)

    return y

libraries = {
    'numpy': np,
    'pytorch': torch,
    'mxnet': mxnet.np,
    'dask': da,
    'tensorflow': tf,
    'jax': jnp,
}

if cp is not None:
    libraries['cupy'] = cp

# Warm-up calls, so that TF and JAX start-up log messages don't end up
# interleaved with the results printed below:
shape = (3, 2)
_ = sum(tf, ones(tf, shape))
_ = sum(jnp, ones(jnp, shape))

print("\nsum(int8_array)\n" + "-"*15)
for name, mod in libraries.items():
    dtype = sum(mod, ones(mod, shape)).dtype
    print(f'{name}: {dtype}')

print("\nstd(int8_array)\n" + "-"*15)
for name, mod in libraries.items():
    try:
        dtype = std(mod, ones(mod, shape)).dtype
        print(f'{name}: {dtype}')
    except Exception as e:
        print(f'{name}: {repr(e)}')

print("\nprod(int8_array)\n" + "-"*16)
for name, mod in libraries.items():
    try:
        dtype = prod(mod, ones(mod, shape)).dtype
        print(f'{name}: {dtype}')
    except Exception as e:
        print(f'{name}: {repr(e)}')

And the result of that:

sum(int8_array)
---------------
numpy: int64
pytorch: torch.int64
mxnet: int8
dask: int64
tensorflow: <dtype: 'int8'>
jax: int32
cupy: int64

std(int8_array)
---------------
numpy: float64
pytorch: RuntimeError('std only supports floating-point dtypes')
mxnet: int8
dask: float64
tensorflow: TypeError('Input must be either real or complex')
jax: float32
cupy: float64

prod(int8_array)
----------------
numpy: int64
pytorch: torch.int64
mxnet: int8
dask: int64
tensorflow: <dtype: 'int8'>
jax: int32
cupy: int64

Conclusions

For sum(int8) and prod(int8) there appear to be two options:

  1. Return the default integer dtype (i.e. upcast the input). This is what NumPy, PyTorch, JAX and CuPy do (and Dask behavior is dictated by the underlying NumPy/CuPy arrays).
  2. Keep the input dtype. This is what TensorFlow and MXNet do. It is also what the spec currently says.
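Both options can be compared within NumPy alone. This is a minimal sketch that assumes passing `dtype=x.dtype` to `np.sum` is a fair stand-in for the "keep the input dtype" behavior of TensorFlow and MXNet (it makes NumPy accumulate in, and return, int8). The default integer dtype that option 1 produces is platform dependent:

```python
import numpy as np

# 200 int8 ones: the true sum (200) does not fit in int8 (max 127).
x = np.ones(200, dtype=np.int8)

# Option 1: NumPy's default upcasts to the default integer dtype.
upcast = np.sum(x)

# Option 2 (emulated): keep the input dtype, so the sum wraps around.
kept = np.sum(x, dtype=x.dtype)

print(upcast, upcast.dtype)  # 200, in the platform default integer dtype
print(kept, kept.dtype)      # -56 int8
```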

The TensorFlow docs do note this as the one inconsistency with NumPy: https://www.tensorflow.org/api_docs/python/tf/math/reduce_sum says "Equivalent to np.sum apart the fact that numpy upcast uint8 and int32 to int64 while tensorflow returns the same dtype as the input."

The MXNet docs at https://mxnet.apache.org/versions/master/api/python/docs/api/np/generated/mxnet.np.sum.html#mxnet-np-sum do not clearly say that this is expected, even though those docs do have a list of differences with NumPy (@szha, thoughts on this?).

For std(int8) there appear to be three options:

  1. Return the default floating-point dtype. This is what NumPy, Dask, JAX and CuPy do.
  2. Raise an exception. For TensorFlow this is consistent with its design, because it doesn't do int-to-float casting. Why PyTorch raises is unclear, probably for historical reasons (it has int-to-float casting now, but did not always have it).
  3. Keep the input dtype. This is what MXNet does. It's a consistent design rule, but clearly doesn't make much sense; I'd expect this one to be a mistake.
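The three observed behaviors for `std(int8)` can be sketched in NumPy terms. `std_of_integers` and its `policy` values are hypothetical names, and the single `TypeError` stands in for both PyTorch's `RuntimeError` and TensorFlow's `TypeError`:

```python
import numpy as np

def std_of_integers(x, policy):
    """Hypothetical helper mimicking the three observed std(int8) behaviors."""
    x = np.asarray(x)
    if policy == "default_float":   # NumPy, Dask, JAX, CuPy
        return np.std(x)            # NumPy returns float64 for integer input
    if policy == "raise":           # PyTorch, TensorFlow
        if not np.issubdtype(x.dtype, np.floating):
            raise TypeError("std only supports floating-point dtypes")
        return np.std(x)
    if policy == "keep":            # MXNet: cast the result back to the input dtype
        return np.std(x).astype(x.dtype)
    raise ValueError(f"unknown policy: {policy!r}")


x = np.ones((3, 2), dtype=np.int8)
print(std_of_integers(x, "default_float").dtype)  # float64
print(std_of_integers(x, "keep").dtype)           # int8
```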

This is all quite inconsistent, and needs to be considered more carefully for all reductions and dtypes.

leofang commented 3 years ago

I feel part of this issue is related to the fact that different implementations handle the intermediate type during reduction differently. On the C++ side this is also a known problem and there is an ongoing proposal to fix this ambiguity (and all the associated issues) by requiring the intermediate type to be either the dtype of the initial value (if there is one) or the input iterator's value type (https://wg21.link/P0571). I would suggest following the C++ behavior if possible (assuming the proposal is accepted), as most libraries have a C++ implementation under the hood. The net effect likely leans toward keeping the input dtype.
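P0571 is a C++ proposal; the sketch below only mimics its rule in NumPy terms, using `np.add.reduce` with an explicit accumulator `dtype`. The helper name `sum_p0571_style` is hypothetical:

```python
import numpy as np

def sum_p0571_style(x, initial=None):
    """Hypothetical helper: accumulate in the dtype of `initial` if one is
    given, otherwise in the input's own dtype."""
    x = np.asarray(x)
    if initial is None:
        return np.add.reduce(x, axis=None, dtype=x.dtype)
    acc_dtype = np.asarray(initial).dtype
    return np.add.reduce(x, axis=None, dtype=acc_dtype, initial=initial)


x = np.ones(200, dtype=np.int8)
print(sum_p0571_style(x))                        # -56: int8 accumulator wraps
print(sum_p0571_style(x, initial=np.int64(0)))   # 200: int64 accumulator
```

Under this rule the caller controls widening through the type of the initial value, which is why the default outcome is "keep the input dtype".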

Pinging the proposal author @brycelelbach for awareness (maybe he could comment on potential pitfalls or challenges).

On the CuPy side, I see no reason not to follow this behavior, except that we want to be NumPy compliant in the main namespace.

rgommers commented 3 years ago

Interesting, thanks @leofang. That proposal goes into a lot of depth about the intermediate type, which matters for implementers but is something we should be agnostic about here imho. The only thing that matters is the output dtype, because that's user-observable behavior. If someone wants to write an implementation where the intermediate type is long double for all input dtypes, and then cast the final result back to the correct output dtype, that should be perfectly fine.
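A minimal sketch of that separation, assuming a hypothetical per-kind table of wide accumulators (64-bit integers, float64, complex128; bool is deliberately not handled). The intermediate dtype is an implementation detail, and only the returned 0-D array's dtype is observable:

```python
import numpy as np

# Hypothetical choice of wide accumulator per dtype kind (bool not handled).
_WIDE_ACCUMULATOR = {"i": np.int64, "u": np.uint64, "f": np.float64, "c": np.complex128}

def sum_with_wide_accumulator(x):
    """Sum in a wide intermediate dtype, then return a 0-D array in the
    input dtype."""
    x = np.asarray(x)
    acc_dtype = _WIDE_ACCUMULATOR[x.dtype.kind]
    total = np.add.reduce(x, axis=None, dtype=acc_dtype)
    return np.asarray(total).astype(x.dtype)


x = np.full(1000, 0.1, dtype=np.float16)
r = sum_with_wide_accumulator(x)
print(r.dtype, r.shape)  # float16 ()
```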

I'm very unsure about what the output dtype should be. Preserving the input dtype sounds nice in theory, but I expect the TF/MXNet behavior to be a footgun in practice. Take for example uint8 and uint16: the main use case for those is image processing. Any practically relevant sum(uint8_image) is going to overflow, and many uint16 images will too. So it seems like a very impractical choice.
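Some worst-case arithmetic makes the point concrete: how many fully saturated pixels can be summed before the accumulator dtype overflows. `max_pixels_before_overflow` is a hypothetical helper built on `np.iinfo`:

```python
import numpy as np

def max_pixels_before_overflow(pixel_dtype, acc_dtype):
    """Largest number of max-valued pixels whose sum still fits in acc_dtype."""
    pixel_max = int(np.iinfo(pixel_dtype).max)
    acc_max = int(np.iinfo(acc_dtype).max)
    return acc_max // pixel_max


for pixel, acc in [(np.uint8, np.uint8), (np.uint8, np.uint16),
                   (np.uint16, np.uint16), (np.uint8, np.uint32)]:
    n = max_pixels_before_overflow(pixel, acc)
    print(f"{np.dtype(pixel).name} pixels, {np.dtype(acc).name} sum: {n} pixels")
```

With a uint8 accumulator even two saturated pixels overflow, and a uint16 accumulator holds only 257 of them, far fewer than a small 640x480 image contains.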

Reductions in NumPy have an initial keyword (see, e.g., the docs for sum), which seems necessary to force an upcast if the behavior were "keep input dtype", but it is hardly used in NumPy code because the default behavior is fine. TF reduce_sum doesn't have this. I couldn't quickly find an issue about this on the TF issue tracker.

Side note: NumPy has something weird going on as well. It's happy to accept negative scalar values in unsigned integer reductions, and it's unclear to me why:

>>> image = np.ones((100, 2), dtype=np.uint8)
>>> np.sum(image, axis=0)
array([100, 100], dtype=uint64)
>>> np.sum(image, axis=0, initial=1)
array([101, 101], dtype=uint64)
>>> np.sum(image, axis=0, initial=-1)
array([99, 99], dtype=uint64)
>>> np.sum(image, axis=0, initial=np.array(-1))
...
TypeError: Cannot cast scalar from dtype('int64') to dtype('uint64') according to the rule 'safe'

rgommers commented 3 years ago

Searching for "reduce_sum overflow" is hard because of "Stack Overflow". But this shows exactly what I thought TF would struggle with: https://github.com/tensorflow/tensorflow/commit/23fde233bf3210759b5a4453bc39101df9c86d0c, https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/reduction_ops.h#L60.

// Specialization for which we do the reduction in IntermediateType to
// avoid integer overflow.
#define CASTING_SPECIALIZATION(ScalarType, IntermediateType)
...
CASTING_SPECIALIZATION(uint8, uint64);
CASTING_SPECIALIZATION(uint16, uint64);
<etc>

And the same for floating point:

// Specialization for BF16 Reducer to fix accuracy.
// TODO: All BF16 reducers should have specializations to fix accuracy.

rgommers commented 3 years ago

JAX discussion: https://github.com/google/jax/issues/3154. It leans toward changing JAX's current behavior to preserving the input dtype.

The discussion there does make sense for, e.g., float32. The trouble is that for 32-bit dtypes, preserving input dtypes seems clearly better, but for lower-precision dtypes it will almost never be the right thing to do.

EDIT: copying the example from that issue to show JAX behavior for all dtypes:

In [1]: import jax.numpy as jnp
   ...: from jax.test_util import dtypes
   ...: from jax import config; config.update('jax_enable_x64', True)
   ...: for dtype in dtypes.all:
   ...:     print(dtype.__name__, "->", jnp.zeros(2, dtype).sum().dtype)
   ...:     
bfloat16 -> bfloat16
float16 -> float16
float32 -> float32
float64 -> float64
int8 -> int64
int16 -> int64
int32 -> int64
int64 -> int64
uint8 -> uint64
uint16 -> uint64
uint32 -> uint64
uint64 -> uint64
complex64 -> complex64
complex128 -> complex128
bool_ -> int64

Also, JAX has a global setting to switch the default dtypes to 64-bit??

kgryte commented 3 years ago

> Preserving input dtype sounds nice in theory, but I expect the TF/MXNet behavior to be a foot gun in practice.

Personally, I think pushing considerations of overflow to userland is fine. If a user has reason to be concerned about overflow during summation or multiplication, then explicitly casting an array to a dtype capable of handling larger values without overflow should be okay. That a user would be forced to think explicitly about desired dtypes is not necessarily a bad thing, imo.

As has been discussed elsewhere, requiring explicit casting may incur costs, such as performance overhead due to memory allocation and/or multiple data passes. However, those costs are likely to be incurred regardless. While libraries such as NumPy may not be able to benefit from whole graph optimization, others may be able to combine casts/reductions into a single operation.
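Under a "keep the input dtype" rule, the explicit cast described above would look like this. `sum_keep_dtype` is a hypothetical stand-in for TensorFlow/MXNet-style behavior, emulated with NumPy's `dtype=` argument:

```python
import numpy as np

def sum_keep_dtype(x):
    """Hypothetical reduction that always returns the input dtype."""
    x = np.asarray(x)
    return np.sum(x, dtype=x.dtype)


image = np.full((16, 16), 255, dtype=np.uint8)

wrapped = sum_keep_dtype(image)                  # accumulates in uint8 and wraps
exact = sum_keep_dtype(image.astype(np.uint32))  # user opts in to a wider dtype

print(wrapped, exact)  # 0 65280
```

The `astype` call allocates a full-size copy, which is the memory overhead mentioned above; a library with whole-graph optimization could fuse the cast into the reduction instead.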

> but for lower-precision dtypes it will almost never be the right thing to do.

I don't agree. Naive summation techniques may overflow, especially if provided a large set of monotonically increasing positive values. However, for summands of mixed sign, various summation techniques are possible which guard against overflow by using correction terms. So I don't think overflow is guaranteed to be rampant.

kgryte commented 3 years ago

Another, potentially left-field, option is to support specifying the output dtype (e.g., output_dtype=None). By default, the behavior could be to return the input dtype. If a user wants to return an alternative dtype, then s/he can opt-in and the underlying implementations can figure out the best way to meet the request.

The issue here, however, is that the specification would be underspecified for mixed-kind input/output dtypes. So one potential constraint could be to require the output dtype to be of the same kind and possibly of equal or greater size.
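A sketch of that opt-in design with the proposed constraint, using NumPy dtype kinds ("i", "u", "f", ...) to define "same kind". Both `check_output_dtype` and the `output_dtype` parameter are hypothetical, not part of any library:

```python
import numpy as np

def check_output_dtype(input_dtype, output_dtype):
    """Hypothetical constraint: same kind, and at least as wide as the input."""
    src, dst = np.dtype(input_dtype), np.dtype(output_dtype)
    if src.kind != dst.kind or dst.itemsize < src.itemsize:
        raise TypeError(f"cannot reduce {src} into {dst}")
    return dst

def sum_with_output_dtype(x, output_dtype=None):
    """Hypothetical sum: input dtype by default, opt-in wider output dtype."""
    x = np.asarray(x)
    dtype = x.dtype if output_dtype is None else check_output_dtype(x.dtype, output_dtype)
    return np.sum(x, dtype=dtype)


x = np.ones(200, dtype=np.int8)
print(sum_with_output_dtype(x))            # -56 (int8 default wraps)
print(sum_with_output_dtype(x, np.int64))  # 200
```

Note that with kind-based checks, int8 to uint64 is rejected even though it is wider, because signed and unsigned integers are different kinds.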

kgryte commented 3 years ago

To be clear, what we are discussing here only applies to a subset of statistical reductions--namely, sum and prod.

For max and min, returning the input dtype is perfectly acceptable.

For mean, var, and std, returning the default floating-point type is the primary practical option, as robust estimation algorithms must apply scaling (particularly for var and std) and thus require a dtype capable of handling decimals. As elsewhere, such as in element-wise mathematical functions, how libraries handle integer dtypes while returning a floating-point dtype (i.e., mixed-kind casting) should be implementation-defined.

That leaves sum and prod, which should be consistent with one another. However, if we don't default to returning the input dtype, then we'll have a third category of promotion rules for a particular subset of statistical reductions, which may or may not be desirable.

brycelelbach commented 3 years ago

The new C++ guidance is to infer the intermediate type from the operator - see P2322.

rgommers commented 3 years ago

> To be clear, what we are discussing here only applies to a subset of statistical reductions--namely, sum and prod.

Yes indeed. There's yet another set, namely the reductions that return bool dtype (any, all).

> For mean, var, and std, returning the default floating-point type is the primary practical option

Agreed. The specification is probably wrong though. It says to return the default floating-point dtype, but it should preserve floating-point input dtypes (while accumulating in higher precision as needed):

>>> np.std(np.ones(3, dtype=np.float32)).dtype
dtype('float32')

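A minimal sketch of that rule: floating-point inputs keep their dtype but are computed in float64 internally, while integer inputs return the default floating-point dtype (assumed to be float64 here). `std_preserving` is a hypothetical name:

```python
import numpy as np

def std_preserving(x):
    """Hypothetical rule: float inputs keep their dtype (computed in float64
    internally); integer inputs return float64, the assumed default float."""
    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.floating):
        return np.std(x, dtype=np.float64).astype(x.dtype)
    return np.std(x, dtype=np.float64)


print(std_preserving(np.array([1, 2, 3, 4], dtype=np.float32)).dtype)  # float32
print(std_preserving(np.ones(3, dtype=np.int8)).dtype)                 # float64
```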
rgommers commented 3 years ago

> Another, potentially left-field, option is to support specifying the output dtype (e.g., output_dtype=None).

Not all that left-field. From the numpy.sum docstring:

dtype : dtype, optional
    The type of the returned array and of the accumulator in which the
    elements are summed.  The dtype of `a` is used by default unless `a`
    has an integer dtype of less precision than the default platform
    integer.  In that case, if `a` is signed then the platform integer
    is used while if `a` is unsigned then an unsigned integer of the
    same precision as the platform integer is used.

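The docstring rule can be written down and checked against `np.sum` directly. This sketch assumes `np.int_` is the "default platform integer" the docstring refers to, and only covers integer and floating-point dtypes (bool is left out because the docstring does not mention it):

```python
import numpy as np

def docstring_sum_dtype(dt):
    """Result dtype that the numpy.sum docstring describes for dtype=None."""
    dt = np.dtype(dt)
    platform_int = np.dtype(np.int_)
    if dt.kind in "iu" and dt.itemsize < platform_int.itemsize:
        # Signed -> platform int; unsigned -> unsigned int of the same width.
        return platform_int if dt.kind == "i" else np.dtype(f"u{platform_int.itemsize}")
    return dt


for dt in (np.int8, np.int16, np.int32, np.uint8, np.uint16, np.float32):
    actual = np.sum(np.ones(3, dtype=dt)).dtype
    print(np.dtype(dt).name, "->", actual, actual == docstring_sum_dtype(dt))
```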
kgryte commented 3 years ago

> The specification is probably wrong though.

You are right. At minimum, we should clarify that returning the default floating-point dtype applies when providing integer dtypes.

szha commented 3 years ago

For dealing with int8, the main use case in deep learning is probably quantization for inference. Intel has a comprehensive library for such usage: https://github.com/intel/lpot

seberg commented 3 years ago

I don't have a formed opinion about special casing integers. I like the consistency of not "up-promoting", but sum and prod may just be special and common enough (it is pretty awkward in NumPy currently).

A few (NumPy-specific) points, most of which are just to read and forget :):

mruberry commented 3 years ago

Just a little context from the PyTorch side:

asmeurer commented 3 years ago

This is also relevant to the trace function in the linear algebra extension (there are other functions in linear algebra that accept integer inputs, but they all take two arguments, so they use normal type promotion).

kgryte commented 3 years ago

This issue should be resolved by gh-238 and gh-260. Closing out...

asmeurer commented 3 years ago

It seems trace was not updated. We should make it consistent with sum. Currently np.trace does the same type promotion as np.sum.
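A quick check of the claim that `np.trace` follows the same promotion as `np.sum`. This exercises NumPy only, not the standard's `trace`:

```python
import numpy as np

for dt in (np.int8, np.uint8, np.int32, np.float32):
    a = np.ones((3, 3), dtype=dt)
    # Compare the result dtype of the diagonal sum with that of a full sum.
    print(np.dtype(dt).name, np.trace(a).dtype, np.sum(a).dtype)
```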

simonetgordon commented 2 years ago

@asmeurer are there plans to update trace sometime soon? Encountering dtype issues in ivy testing.

asmeurer commented 2 years ago

I opened a new issue to track this https://github.com/data-apis/array-api/issues/493. I would just make the change, but I'm not sure if we should also add a dtype argument to trace() to match sum() and prod().