elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

fix: Nx.Random.shuffle repeating a single value in certain cases on GPU #1552

Closed polvalente closed 3 weeks ago

polvalente commented 3 weeks ago

fixes #1551

In short, we were not testing a specific edge case in which the sort_keys sub-input to sort_key_val ended up with an invalid shape, which in turn yielded an invalid shuffle when the input is 1D.
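For readers unfamiliar with the approach, here is a rough sketch of what a shuffle built on random sort keys looks like. It is plain Elixir with no Nx dependency and is not the Nx implementation; the module SortKeyShuffle and its functions are hypothetical names used only for illustration. Each value is paired with a random key, the pairs are sorted by key, and the values are read back out, which is why malformed keys would be expected to produce a malformed shuffle.

```elixir
defmodule SortKeyShuffle do
  # Hypothetical helper for illustration only; not part of Nx.

  # Shuffle a list by pairing every value with a random sort key,
  # sorting the pairs by key, and dropping the keys again.
  def shuffle(values) do
    values
    |> Enum.map(fn value -> {:rand.uniform(), value} end)
    |> Enum.sort_by(fn {key, _value} -> key end)
    |> Enum.map(fn {_key, value} -> value end)
  end

  # A valid shuffle is a permutation: same elements, any order.
  def permutation?(original, shuffled) do
    Enum.sort(original) == Enum.sort(shuffled)
  end
end

input = Enum.to_list(0..9)
shuffled = SortKeyShuffle.shuffle(input)

IO.inspect(shuffled, label: "shuffled")
IO.inspect(SortKeyShuffle.permutation?(input, shuffled), label: "permutation?")
```

The permutation check is the kind of property a regression test for this edge case could assert: a result that repeats a single value, as in the reported bug, fails it.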

jonatanklosko commented 3 weeks ago

I'm still trying to narrow it down, but the issue must be somewhere in the combination of random_bits and take_along_axis. For example, if we make this change:

-sort_keys = random_bits(keys[1], shape: tensor.shape)
+sort_keys = randint_split(keys[1], 0, uint32max, shape: tensor.shape)

It seems to work just fine.

jonatanklosko commented 3 weeks ago

@polvalente actually, this also fixes it:

-sort_keys = random_bits(keys[1], shape: tensor.shape)
+sort_keys = random_bits(keys[1], shape: tensor.shape) |> Nx.as_type(:s32)