polvalente closed this 3 weeks ago.
I'm still trying to narrow it down, but the issue must be somewhere in the combination of random_bits and take_along_axis. For example, if we make this change:
-sort_keys = random_bits(keys[1], shape: tensor.shape)
+sort_keys = randint_split(keys[1], 0, uint32max, shape: tensor.shape)
It seems to work just fine.
@polvalente actually, this also fixes it:
-sort_keys = random_bits(keys[1], shape: tensor.shape)
+sort_keys = random_bits(keys[1], shape: tensor.shape) |> Nx.as_type(:s32)
fixes #1551
In short, we were not testing a specific edge case that produced an invalid-shaped tensor for the sort_keys sub-input in sort_key_val, which in turn yielded an invalid shuffle when the input was 1D.
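For readers unfamiliar with the technique, here is a minimal sketch of a shuffle by sorting random keys. It is a pure-Elixir illustration using only the standard :rand and Enum modules. It is not Nx's implementation and does not reproduce the bug. The module name SortShuffleSketch and the rounds parameter are hypothetical. The sketch does show the invariant that the bad sort_keys shape violated: there must be exactly one random key per element along the shuffled axis.

```elixir
defmodule SortShuffleSketch do
  # Hypothetical illustration, not Nx code: shuffle a list by pairing each
  # element with a random unsigned 32-bit key and sorting by that key.

  # Keys are drawn from 0..2^32 - 1, mirroring a uint32 key space.
  @key_space 4_294_967_296

  # `rounds` is a hypothetical knob: Enum.sort_by/2 is stable, so elements
  # whose keys collide keep their input order; extra rounds reduce that bias.
  def shuffle(list, rounds \\ 2) when is_list(list) and rounds >= 1 do
    Enum.reduce(1..rounds, list, fn _round, acc ->
      acc
      # One key per element; :rand.uniform(n) returns 1..n, so shift to 0..n-1.
      |> Enum.map(fn x -> {:rand.uniform(@key_space) - 1, x} end)
      # Sort elements by their random keys (the role played by sort_keys).
      |> Enum.sort_by(fn {key, _x} -> key end)
      # Drop the keys and keep only the permuted values.
      |> Enum.map(fn {_key, x} -> x end)
    end)
  end
end
```

For example, SortShuffleSketch.shuffle(Enum.to_list(1..10)) returns a random permutation of 1..10. Its output differs from run to run, but it always contains the same elements as the input.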