elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

Fix argmax/argmin behaviour with NaNs #1499

Closed jonatanklosko closed 1 month ago

jonatanklosko commented 1 month ago

Currently the behaviour of argmax/argmin for NaNs is inconsistent with reduce_max/reduce_min:

iex> Nx.reduce_max(Nx.tensor([1.0, :nan, 2.0]))
#Nx.Tensor<
  f32
  NaN
>
iex> Nx.argmax(Nx.tensor([1.0, :nan, 2.0]))
#Nx.Tensor<
  s64
  2
>

This is the case for both EXLA and BinaryBackend. I changed both so that argmax/argmin point to a NaN in that case. Torchx already adheres to this behaviour.

I specifically changed the implementation so that MLIR aligns exactly with JAX.
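
The intended semantics can be sketched in plain Elixir, as a rough illustration only. The module NanArgSketch and its functions are hypothetical and not part of Nx, EXLA, or Torchx; NaN is represented by the atom :nan in a plain list, mirroring how Nx.tensor/1 accepts it. The sketch assumes the first NaN wins when several are present.

```elixir
defmodule NanArgSketch do
  # Hypothetical helper module for illustration; not part of Nx.
  # Both functions return the index of the first :nan if one is present,
  # otherwise the index of the first largest / smallest element.
  def argmax(list), do: arg_by(list, &Kernel.>/2)
  def argmin(list), do: arg_by(list, &Kernel.</2)

  defp arg_by(list, better?) do
    list
    |> Enum.with_index()
    |> Enum.reduce(fn
      # Once a NaN has been selected, nothing can replace it.
      _cur, {:nan, _} = acc -> acc
      # The first NaN encountered replaces any ordinary candidate.
      {:nan, _} = cur, _acc -> cur
      # Ordinary comparison; strict, so ties keep the earlier index.
      {v, _} = cur, {best, _} = acc -> if better?.(v, best), do: cur, else: acc
    end)
    |> elem(1)
  end
end
```

With this sketch, NanArgSketch.argmax([1.0, :nan, 2.0]) and NanArgSketch.argmin([1.0, :nan, 2.0]) both return 1, the position of the NaN, which is consistent with a reduce_max/reduce_min that yields NaN for the same input.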

josevalim commented 1 month ago

Note there is this commit: 3d1b594395184e00c1c281b8a813e13fb5c966ac

Does it make sense to have NaN be both the min and the max? Also, can you please run scholar tests on this branch? It does use NaN for sorting quite extensively.

jonatanklosko commented 1 month ago

> Note there is this commit: https://github.com/elixir-nx/nx/commit/3d1b594395184e00c1c281b8a813e13fb5c966ac

I think sorting is different. Since NaNs are not really comparable to anything, they don't have a correct position, so throwing all of them to the end is reasonable.
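
For contrast, the "NaNs go to the end" convention for sorting can be sketched as follows. This is only an illustration in plain Elixir; NanSortSketch.sort_nans_last/1 is a hypothetical helper, not Nx.sort/2, and NaN is again represented by the atom :nan.

```elixir
defmodule NanSortSketch do
  # Hypothetical helper for illustration; not part of Nx.
  # Sorts numbers in ascending order and places every :nan at the end.
  def sort_nans_last(list) do
    Enum.sort_by(list, fn
      # NaNs get a key that sorts after every ordinary number.
      :nan -> {1, 0}
      # Ordinary numbers sort among themselves by value.
      v -> {0, v}
    end)
  end
end
```

Here the NaNs are kept but given no meaningful rank, whereas in the argmax/argmin case a NaN determines the result.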

> Does it make sense to have NaN be both the min and the max?

I think so. One way to view it is that, when NaNs are present, there is no difference between min and max, since NaNs are not comparable. It's basically the same as taking the sum of the tensor: a single NaN makes the whole result NaN.

> Also, can you please run scholar tests on this branch?

All tests pass.