jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

max/min of empty array should be least/greatest element of its dtype #18661

Closed: carlosgmartin closed this issue 1 year ago

carlosgmartin commented 1 year ago

The identity element of max (join) is the least element (bottom) of the input array's dtype. The identity element of min (meet) is the greatest element (top) of the input array's dtype.

Therefore, the max and min of an empty array should be the least and greatest element of its dtype, respectively. This can be implemented by editing jax._src.numpy.reductions.

Example implementation:

from jax import numpy as jnp

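def min_value(dtype):
    # Least element of dtype, i.e. the identity of max.
    try:
        return jnp.iinfo(dtype).min  # integer dtypes
    except ValueError:
        pass

    try:
        return jnp.finfo(dtype).min  # floating-point dtypes
    except ValueError:
        pass

    if dtype == bool:  # bool is covered by neither iinfo nor finfo
        return False

def max_value(dtype):
    # Greatest element of dtype, i.e. the identity of min.
    try:
        return jnp.iinfo(dtype).max  # integer dtypes
    except ValueError:
        pass

    try:
        return jnp.finfo(dtype).max  # floating-point dtypes
    except ValueError:
        pass

    if dtype == bool:  # bool is covered by neither iinfo nor finfo
        return True

def safe_max(a, initial=None):
    # Default `initial` to the identity of max so empty inputs don't raise.
    if initial is None:
        initial = min_value(a.dtype)
    return a.max(initial=initial)

def safe_min(a, initial=None):
    # Default `initial` to the identity of min so empty inputs don't raise.
    if initial is None:
        initial = max_value(a.dtype)
    return a.min(initial=initial)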

def main():
    for dtype in [bool, jnp.uint8, jnp.float32, jnp.int32]:
        empty = jnp.array([], dtype)
        assert empty.sum() == 0
        assert empty.prod() == 1

        try:
            assert empty.max() == min_value(empty.dtype)
        except ValueError as exception:
            print(repr(exception))

        try:
            assert empty.min() == max_value(empty.dtype)
        except ValueError as exception:
            print(repr(exception))

        assert safe_max(empty) == min_value(empty.dtype)
        assert safe_min(empty) == max_value(empty.dtype)

if __name__ == "__main__":
    main()

Output:

ValueError('zero-size array to reduction operation max which has no identity')
ValueError('zero-size array to reduction operation min which has no identity')
ValueError('zero-size array to reduction operation max which has no identity')
ValueError('zero-size array to reduction operation min which has no identity')
ValueError('zero-size array to reduction operation max which has no identity')
ValueError('zero-size array to reduction operation min which has no identity')
ValueError('zero-size array to reduction operation max which has no identity')
ValueError('zero-size array to reduction operation min which has no identity')
jakevdp commented 1 year ago

jax.numpy.min and jax.numpy.max follow the semantics of numpy.min and numpy.max, which do not assume any implicit identity. As you demonstrated, you can specify the identity you'd like to use by passing it as the initial argument of either function.
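A minimal sketch of how `initial` behaves in `jnp.max`/`jnp.min` (the arrays and values are illustrative, assuming JAX's default 32-bit mode): the initial value takes part in the reduction like one extra element, so it is returned for an empty input and otherwise only matters when it beats every element.

```python
import jax.numpy as jnp

x = jnp.array([3, 1, 2])
empty = jnp.array([], dtype=x.dtype)

# `initial` takes part in the reduction as if it were one more element...
print(jnp.max(x, initial=0).item())   # 3: 0 is below every element
print(jnp.max(x, initial=10).item())  # 10: initial wins when it is largest

# ...so for an empty input it is simply returned instead of raising.
print(jnp.max(empty, initial=-1).item())  # -1
print(jnp.min(empty, initial=99).item())  # 99
```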

carlosgmartin commented 1 year ago

@jakevdp I see. It's unfortunate that numpy lacks the mathematically correct semantics. I commented on that here.

Are there existing functions like min_value(dtype)/max_value(dtype) in jax, that work for all the relevant dtypes?

jakevdp commented 1 year ago

Thanks for the link – I have to say I agree with the reasons given in the linked issue to not have the proposed behavior by default.

In particular, I think it would be very surprising if e.g. jnp.arange(N).max() were to return -2147483648 when N = 0. An error in this case is much more user-friendly, in my opinion, and achieving the alternate behavior is as easy as passing a single keyword argument.
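To make that example concrete, a sketch assuming JAX's default 32-bit mode, where `jnp.arange(0)` is an empty int32 array: the current behavior raises, whereas an identity-based default would have silently produced the int32 minimum.

```python
import jax.numpy as jnp

values = jnp.arange(0)  # empty; int32 unless 64-bit mode is enabled

# Current behavior: an explicit error rather than a silent sentinel value.
try:
    values.max()
    outcome = "no error"
except ValueError as err:
    outcome = f"ValueError: {err}"
print(outcome)

# The value an identity-based default would have returned for int32.
print(jnp.iinfo(jnp.int32).min)  # -2147483648
```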

Returning the additive identity and multiplicative identity in the case of sums/products over empty arrays seems fine to me, because both are representable values in every dtype (also, x * 0 and x ** 0 are well-defined functions returning those identities).
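A quick check of that claim, as a sketch over a few illustrative dtypes: the empty sum and product come out as 0 and 1 in each case, and `x * 0` and `x ** 0` map every element to those same values.

```python
import jax.numpy as jnp

sums, prods = {}, {}
for dtype in [jnp.bool_, jnp.uint8, jnp.int32, jnp.float32]:
    empty = jnp.array([], dtype)
    # 0 and 1 are representable in every one of these dtypes.
    sums[str(empty.dtype)] = empty.sum().item()
    prods[str(empty.dtype)] = empty.prod().item()
print(sums)
print(prods)

# x * 0 and x ** 0 send any element to those same identities.
x = jnp.array([3, -7, 42], dtype=jnp.int32)
print(x * 0)   # [0 0 0]
print(x ** 0)  # [1 1 1]
```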

As far as I know, there is no built-in JAX function that does what you have in mind. But fortunately, you can define it yourself in a couple of lines using the initial argument.

carlosgmartin commented 1 year ago

Sorry, I meant the functions that return the least/greatest element of a given dtype, not the safe_max/safe_min.

jakevdp commented 1 year ago

functions that return the least/greatest element of a given dtype,

There's jnp.iinfo(dtype).min for integer dtypes, and jnp.finfo(dtype).min for inexact types, but it seems you already know about those.

carlosgmartin commented 1 year ago

Seems like neither of those handles bool. Does the union of all three cover all cases?

jakevdp commented 1 year ago

Yes, this should do it:

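import numpy as np
from jax import numpy as jnp

def min_val(dtype):
  # bool is covered by neither iinfo nor finfo, so special-case it first.
  return (False if dtype == bool
          else np.iinfo(dtype).min if jnp.issubdtype(dtype, jnp.integer)
          else jnp.finfo(dtype).min)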
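
Pairing that helper with a `max_val` counterpart gives identities for both reductions. This is a sketch restricted to bool, integer, and real floating-point dtypes; `max_val` and the `results` dict are illustrative names, and complex dtypes are not handled.

```python
from jax import numpy as jnp

def min_val(dtype):
  """Least element of `dtype`, i.e. the identity of max."""
  if dtype == bool:  # bool is covered by neither iinfo nor finfo
    return False
  if jnp.issubdtype(dtype, jnp.integer):
    return jnp.iinfo(dtype).min
  return jnp.finfo(dtype).min  # assumes a real floating-point dtype

def max_val(dtype):
  """Greatest element of `dtype`, i.e. the identity of min."""
  if dtype == bool:
    return True
  if jnp.issubdtype(dtype, jnp.integer):
    return jnp.iinfo(dtype).max
  return jnp.finfo(dtype).max

# Max and min of an empty array, using each dtype's extreme as `initial`.
results = {}
for dtype in [bool, jnp.uint8, jnp.int32, jnp.float32]:
  empty = jnp.array([], dtype)
  results[str(empty.dtype)] = (
      empty.max(initial=min_val(empty.dtype)).item(),
      empty.min(initial=max_val(empty.dtype)).item(),
  )
print(results)
```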
jakevdp commented 1 year ago

I've thought about this more, and I think I've put a finger on why I feel OK with sum and prod of an empty array returning their identity, but min and max raising an error.

I'll grant your point that any particular dtype can be viewed as a finite set for which the minimum/maximum element represents the identity of the max/min reduction over elements of that set. But for the integers and the real numbers, such an identity doesn't exist! For example, in the (countably infinite) set of integers, there is no element $e$ such that $\min(e, a) = a$ for all $a$ in the set. Likewise, in the (uncountably infinite) set of real numbers, there is no element $e$ such that $\max(e, a) = a$ for all $a$ in the set.

It's true that $\infty$ and $-\infty$ would work as identity elements, but infinity is not a member of the set of integers or the set of real numbers!
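The argument can be written out as a short derivation (a sketch; the counterexample $a = e - 1$ is added here for illustration). The same reasoning applies to min with least and greatest swapped.

```latex
% e is an identity for max on a set S when it never changes the result:
e \text{ is an identity for } \max \text{ on } S
  \iff \max(e, a) = a \quad \text{for all } a \in S.

% No such e exists in \mathbb{Z} or \mathbb{R}: for any candidate e,
% the element a = e - 1 is also in S, and
\max(e,\, e - 1) = e \neq e - 1.

% In the extended reals, and in any finite dtype such as int32,
% a least element exists and serves as the identity:
\max(-\infty, a) = a \;\; \forall a \in \mathbb{R} \cup \{-\infty, +\infty\},
\qquad
\max(-2^{31}, a) = a \;\; \forall a \in \texttt{int32}.
```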

So if we're treating float32 or int32 as sets in their own right, we could use their min/max elements as identities to max/min reductions. But if we're using them to approximate real arithmetic, we should not expect min or max to have an identity. The fact that the identities exist in the particular implementation is just an implementation detail, not something fundamental that we should build our arithmetic APIs upon.

Contrast this with sum, prod, any, all, and you'll see that for all of those, the identity element exists and is representable in the particular dtype implementation we're using.

I suspect this is why Python, NumPy, JAX, and probably others return identity elements when empty sequences are passed to sum, prod, any, and all, but raise errors when empty sequences are passed to min and max: it's because the intent of these operations is to approximate the properties of real-valued mathematics, not to reflect the details of the implementation.
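That pattern can be checked directly in Python and NumPy; a minimal sketch: the reductions with a representable identity return it for an empty list, while `min`/`max` raise unless given Python's `default` keyword or NumPy's `initial` argument.

```python
import math
import numpy as np

empty = []

# sum, prod, any and all have identities, so they return them for empty input.
identities = (sum(empty), math.prod(empty), any(empty), all(empty))
print(identities)  # (0, 1, False, True)

# min and max have no such identity, so both Python and NumPy raise.
raised = []
for reducer in (min, max, np.min, np.max):
    try:
        reducer(empty)
        raised.append(False)
    except ValueError:
        raised.append(True)
print(raised)  # [True, True, True, True]

# Both offer an explicit opt-in: Python's `default`, NumPy's `initial`.
print(max(empty, default=float("-inf")))      # -inf
print(np.max(np.array([]), initial=-np.inf))  # -inf
```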

carlosgmartin commented 1 year ago

Floats could be said to approximate the extended reals (which complete the reals with $\pm \infty$), but your point about integers makes sense. (Unsigned integers, i.e. naturals, do have a minimum element, namely zero, but signed integers don't.) I figured using the minimum representable element of each dtype as the max identity would be the best way to treat dtypes uniformly, while still giving a mathematically meaningful identity for floats in particular (and, say, for max over unsigned ints).
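A sketch of the two cases mentioned here (the arrays are illustrative): for `uint8`, 0 is both the least representable value and the least natural number, and for floats, `-inf` acts as the identity of max, so passing either as `initial` leaves non-empty results unchanged.

```python
import jax.numpy as jnp

u = jnp.array([7, 200, 3], dtype=jnp.uint8)
u_empty = jnp.array([], dtype=jnp.uint8)
f = jnp.array([-1e30, 2.5], dtype=jnp.float32)
f_empty = jnp.array([], dtype=jnp.float32)

# 0 is the least uint8 value, so it never changes the max of a non-empty
# array and is returned for an empty one.
u_max = u.max(initial=0).item()
u_empty_max = u_empty.max(initial=0).item()

# -inf lies below every finite float (as in the extended reals), so it
# plays the same role for floating-point max.
f_max = f.max(initial=-jnp.inf).item()
f_empty_max = f_empty.max(initial=-jnp.inf).item()

print(u_max, u_empty_max, f_max, f_empty_max)  # 200 0 2.5 -inf
```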

jakevdp commented 1 year ago

Sure, but you're pointing out features of the implementation. Floats and ints (both signed and unsigned) are representations of real-valued arithmetic, for which there is no minimum or maximum element. Treating them as bounded sets focuses on the implementation, rather than what the implementation represents.

You're free to pass the initial value if you want the behavior to be different, but I don't think the behavior you propose is the appropriate default (and both Python min/max and NumPy min/max have made the same choice).