carlosgmartin closed this issue 1 year ago.
`jax.numpy.min` and `jax.numpy.max` follow the semantics of `numpy.min` and `numpy.max`, which do not assume any implicit identity. As you demonstrated, you can specify the identity you'd like to use by passing it as the `initial` argument of either function.
@jakevdp I see. It's unfortunate that numpy lacks the mathematically correct semantics. I commented on that here.
Are there existing functions like `min_value(dtype)`/`max_value(dtype)` in JAX that work for all the relevant dtypes?
Thanks for the link – I have to say I agree with the reasons given in the linked issue to not have the proposed behavior by default.
In particular, I think it would be very surprising if e.g. `jnp.arange(N).max()` were to return `-2147483648` when `N = 0`. An error in this case is much more user-friendly, in my opinion, and achieving the alternate behavior is as easy as passing a single keyword argument.
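A minimal sketch of both behaviors, assuming a recent JAX release: without an identity the reduction raises, and one keyword argument opts in to an explicit identity. The choice of `jnp.iinfo(x.dtype).min` as the identity is just one illustrative option.

```python
import jax.numpy as jnp

x = jnp.arange(0)  # empty integer array

# Without an identity, max of an empty array is an error (as in NumPy)
try:
    x.max()
    raised = False
except ValueError as err:
    raised = True
    print("ValueError:", err)

# Opting in to an identity takes a single keyword argument
result = jnp.max(x, initial=jnp.iinfo(x.dtype).min)
print(result)
```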
Returning the additive identity and multiplicative identity in the case of sums/products over empty arrays seems fine to me, because both are representable values in every dtype (also, `x * 0` and `x ** 0` are well-defined functions returning those identities).
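For illustration, a short sketch showing that the empty-array results of `sum`/`prod` coincide with what `x * 0` and `x ** 0` produce elementwise (assuming default 32-bit JAX dtypes):

```python
import jax.numpy as jnp

empty = jnp.zeros((0,), dtype=jnp.float32)
total = jnp.sum(empty)     # additive identity
product = jnp.prod(empty)  # multiplicative identity
print(total, product)

x = jnp.arange(5)
print(x * 0)   # every element is the additive identity
print(x ** 0)  # every element is the multiplicative identity
```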
As far as I know, there is no function built into JAX that does what you have in mind. But fortunately, you can define it yourself in a couple of lines using the `initial` argument.
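One possible sketch of such helpers is below. `safe_max` and `safe_min` are hypothetical names, not JAX APIs, and the sketch assumes a floating or integer input dtype (bool and complex are not handled). For floats it uses `±inf`, and for integers it uses the dtype's bounds from `jnp.iinfo`.

```python
import jax.numpy as jnp

def safe_max(x):
    # Hypothetical helper: max with an explicit identity, so empty inputs don't raise.
    # Assumes x has a floating or integer dtype.
    if jnp.issubdtype(x.dtype, jnp.floating):
        identity = -jnp.inf
    else:
        identity = jnp.iinfo(x.dtype).min
    return jnp.max(x, initial=identity)

def safe_min(x):
    # Hypothetical helper: the mirror image of safe_max.
    if jnp.issubdtype(x.dtype, jnp.floating):
        identity = jnp.inf
    else:
        identity = jnp.iinfo(x.dtype).max
    return jnp.min(x, initial=identity)

print(safe_max(jnp.zeros((0,), jnp.float32)))  # empty float input
print(safe_max(jnp.array([1, 5, 3])))          # non-empty input is unaffected
```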
Sorry, I meant the functions that return the least/greatest element of a given dtype, not `safe_max`/`safe_min`.
> functions that return the least/greatest element of a given dtype,
There's `jnp.iinfo(dtype).min` for integer dtypes, and `jnp.finfo(dtype).min` for inexact types, but it seems you already know about those.
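A quick look at what these return for a few dtypes. One point worth noting is that `finfo(...).min` is the most negative *finite* value, not `-inf`.

```python
import jax.numpy as jnp

# Integer bounds come from iinfo (returned as Python ints)
print(jnp.iinfo(jnp.int8).min, jnp.iinfo(jnp.int8).max)    # -128 127
print(jnp.iinfo(jnp.uint8).min, jnp.iinfo(jnp.uint8).max)  # 0 255

# Float bounds come from finfo; finfo.min is the most negative finite value
f16_min = jnp.finfo(jnp.float16).min
print(f16_min)
```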
Seems like neither of those handles `bool`. Does the union of all three cover all cases?
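Yes, this should do it:

```python
import numpy as np
import jax.numpy as jnp

def min_val(dtype):
    # Least element of `dtype`: False for bool, iinfo.min for integers, finfo.min for inexact types
    return (False if np.dtype(dtype) == bool else np.iinfo(dtype).min
            if jnp.issubdtype(dtype, jnp.integer) else jnp.finfo(dtype).min)
```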
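For the other direction, a possible mirror image is sketched below. `max_val` is a hypothetical name introduced here, and `min_val` is restated so that the block is self-contained. Passing these as `initial` gives the "least/greatest element" identities for `max`/`min`. For floats, this sketch follows the snippet above and uses the finite `finfo` bounds rather than `±inf`.

```python
import numpy as np
import jax.numpy as jnp

def min_val(dtype):
    # Least element of `dtype` (identity for max)
    if np.dtype(dtype) == bool:
        return False
    if jnp.issubdtype(dtype, jnp.integer):
        return jnp.iinfo(dtype).min
    return jnp.finfo(dtype).min  # finite minimum, not -inf

def max_val(dtype):
    # Hypothetical mirror of min_val: greatest element of `dtype` (identity for min)
    if np.dtype(dtype) == bool:
        return True
    if jnp.issubdtype(dtype, jnp.integer):
        return jnp.iinfo(dtype).max
    return jnp.finfo(dtype).max  # finite maximum, not +inf

empty = jnp.zeros((0,), dtype=jnp.int8)
print(jnp.max(empty, initial=min_val(empty.dtype)))
print(jnp.min(empty, initial=max_val(empty.dtype)))
```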
I've thought about this more, and I think I've put my finger on why I feel OK with `sum` and `prod` of an empty array returning their identity, but `min` and `max` raising an error.
I'll grant your point that any particular `dtype` can be viewed as a finite set for which the minimum/maximum element represents the identity of the `max`/`min` reduction over elements of that set. But for real numbers, such an identity doesn't exist! For example, in the (countably infinite) set of integers, there does not exist any element $e$ such that $\min(e, a) = a$ for all $a$ in the set. Likewise, in the (uncountably infinite) set of real numbers, there does not exist any element $e$ such that $\max(e, a) = a$ for all $a$ in the set.
It's true that $\infty$ and $-\infty$ would work as identity elements, but infinity is not a member of the set of integers or the set of real numbers!
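The argument can be stated compactly. An identity for $\max$ on a set $S$ is the same thing as a least element of $S$:

$$
e \in S \text{ and } \max(e, a) = a \;\; \forall a \in S
\quad\iff\quad
e \in S \text{ and } e \le a \;\; \forall a \in S
\quad\iff\quad
e = \min S
$$

So $\max$ has an identity on $S$ exactly when $S$ has a least element. $\mathbb{Z}$ and $\mathbb{R}$ have none, while $\mathbb{N}$ has $0$ and the extended reals $\mathbb{R} \cup \{\pm\infty\}$ have $-\infty$. The same reasoning applies to $\min$ with greatest elements.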
So if we're treating `float32` or `int32` as sets in their own right, we could use their min/max elements as identities for max/min reductions. But if we're using them to approximate real arithmetic, we should not expect `min` or `max` to have an identity. The fact that the identities exist in the particular implementation is just an implementation detail, not something fundamental that we should build our arithmetic APIs upon.
Contrast this with `sum`, `prod`, `any`, and `all`, and you'll see that for all of those, the identity element exists and is representable in the particular `dtype` implementation we're using.
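For `any` and `all` the identities are the boolean values themselves. In this short sketch, `False` is the identity of logical OR and `True` is the identity of logical AND:

```python
import jax.numpy as jnp

empty_bool = jnp.zeros((0,), dtype=bool)
any_result = jnp.any(empty_bool)  # identity of logical OR
all_result = jnp.all(empty_bool)  # identity of logical AND
print(any_result, all_result)
```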
I suspect this is why Python, NumPy, JAX, and probably others return identity elements when empty sequences are passed to `sum`, `prod`, `any`, and `all`, but raise errors when empty sequences are passed to `min` and `max`: it's because the intent of these operations is to approximate the properties of real-valued mathematics, not to reflect the details of the implementation.
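Python's standard library shows the same split. This sketch uses only builtins and `math.prod` (Python 3.8+). The `default` keyword of `min`/`max` plays the role of JAX's `initial` here: it is what `min` returns for an empty iterable.

```python
import math

print(sum([]), math.prod([]))  # 0 1
print(any([]), all([]))        # False True

try:
    min([])
    min_raised = False
except ValueError as err:
    min_raised = True
    print("min([]) raised ValueError:", err)

# Explicit opt-in, analogous to passing `initial`
print(min([], default=float("inf")))  # inf
```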
Floats could be said to approximate the extended reals (which complete the reals with $\pm \infty$), but your point about integers makes sense. (Unsigned integers, i.e. naturals, do have a minimum element, namely zero, but signed integers don't.) I figured using the minimum representable element of each dtype as the max identity would be the best way to treat dtypes uniformly, while allowing it to be defined for floats in particular (and, say, max for unsigned ints).
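To make the two cases mentioned here concrete, the sketch below passes each natural identity explicitly via `initial`. The least element `0` is used for unsigned integers, and `-inf` from the extended-reals view is used for floats:

```python
import jax.numpy as jnp

# Unsigned integers have a genuine least element, 0
u = jnp.zeros((0,), dtype=jnp.uint32)
u_max = jnp.max(u, initial=0)

# Floats can represent -inf, the least element of the extended reals
f = jnp.zeros((0,), dtype=jnp.float32)
f_max = jnp.max(f, initial=-jnp.inf)
print(u_max, f_max)
```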
Sure, but you're pointing out features of the implementation. Floats and ints (both signed and unsigned) are representations of real-valued arithmetic, for which there is no minimum or maximum element. Treating them as bounded sets focuses on the implementation, rather than what the implementation represents.
You're free to pass the `initial` value if you want the behavior to be different, but I don't think the behavior you propose is the appropriate default (and both Python `min`/`max` and NumPy `min`/`max` have made the same choice).
The identity element of `max` (join) is the least element (bottom) of the input array's dtype. The identity element of `min` (meet) is the greatest element (top) of the input array's dtype. Therefore, the `max` and `min` of an empty array should be the least and greatest element of its dtype, respectively. This can be implemented by editing `jax._src.numpy.reductions`.

Example implementation: