Closed: VictorT42 closed this issue 3 weeks ago.
Given that this works on both EXLA on CPU and Nx.BinaryBackend, everything points towards some upstream bug in XLA.
@jonatanklosko any ideas on what we should report there?
This is the MLIR module for &Nx.Random.shuffle/2:
module {
func.func public @main(%arg0: tensor<2xui32>, %arg1: tensor<1000000xi32>) -> (tensor<1000000xi32>, tensor<2xui32>) {
%c = stablehlo.constant dense<0> : tensor<i32>
%0 = stablehlo.iota dim = 0 : tensor<1000000xi32>
%1 = stablehlo.reshape %0 : (tensor<1000000xi32>) -> tensor<1000000xi32>
%2:3 = stablehlo.while(%iterArg = %c, %iterArg_0 = %1, %iterArg_1 = %arg0) : tensor<i32>, tensor<1000000xi32>, tensor<2xui32>
cond {
%c_2 = stablehlo.constant dense<2> : tensor<ui32>
%4 = stablehlo.convert %iterArg : (tensor<i32>) -> tensor<i64>
%5 = stablehlo.convert %c_2 : (tensor<ui32>) -> tensor<i64>
%6 = stablehlo.compare LT, %4, %5, NOTYPE : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %6 : tensor<i1>
} do {
%c_2 = stablehlo.constant dense<1> : tensor<i32>
%4 = stablehlo.add %c_2, %iterArg : tensor<i32>
%c_3 = stablehlo.constant dense<1> : tensor<ui32>
%5 = stablehlo.iota dim = 0 : tensor<1000000xui32>
%6 = stablehlo.reshape %5 : (tensor<1000000xui32>) -> tensor<2x500000xui32>
%7 = stablehlo.slice %6 [0:1, 0:500000] : (tensor<2x500000xui32>) -> tensor<1x500000xui32>
%8 = stablehlo.reshape %7 : (tensor<1x500000xui32>) -> tensor<500000xui32>
%9 = stablehlo.iota dim = 0 : tensor<4xui32>
%10 = stablehlo.reshape %9 : (tensor<4xui32>) -> tensor<2x2xui32>
%11 = stablehlo.slice %10 [0:1, 0:2] : (tensor<2x2xui32>) -> tensor<1x2xui32>
%12 = stablehlo.reshape %11 : (tensor<1x2xui32>) -> tensor<2xui32>
%13 = stablehlo.slice %iterArg_1 [0:1] : (tensor<2xui32>) -> tensor<1xui32>
%14 = stablehlo.reshape %13 : (tensor<1xui32>) -> tensor<ui32>
%15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor<ui32>) -> tensor<2xui32>
%16 = stablehlo.add %12, %15 : tensor<2xui32>
%17 = stablehlo.slice %10 [1:2, 0:2] : (tensor<2x2xui32>) -> tensor<1x2xui32>
%18 = stablehlo.reshape %17 : (tensor<1x2xui32>) -> tensor<2xui32>
%19 = stablehlo.slice %iterArg_1 [1:2] : (tensor<2xui32>) -> tensor<1xui32>
%20 = stablehlo.reshape %19 : (tensor<1xui32>) -> tensor<ui32>
%21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor<ui32>) -> tensor<2xui32>
%22 = stablehlo.add %18, %21 : tensor<2xui32>
%23 = stablehlo.xor %14, %20 : tensor<ui32>
%c_4 = stablehlo.constant dense<466688986> : tensor<i32>
%24 = stablehlo.convert %23 : (tensor<ui32>) -> tensor<i64>
%25 = stablehlo.convert %c_4 : (tensor<i32>) -> tensor<i64>
%26 = stablehlo.xor %24, %25 : tensor<i64>
%27 = stablehlo.convert %26 : (tensor<i64>) -> tensor<ui32>
%c_5 = stablehlo.constant dense<[13, 15, 26, 6]> : tensor<4xui8>
%c_6 = stablehlo.constant dense<[17, 29, 16, 24]> : tensor<4xui8>
%28:8 = stablehlo.while(%iterArg_9 = %c_3, %iterArg_10 = %16, %iterArg_11 = %22, %iterArg_12 = %20, %iterArg_13 = %27, %iterArg_14 = %14, %iterArg_15 = %c_5, %iterArg_16 = %c_6) : tensor<ui32>, tensor<2xui32>, tensor<2xui32>, tensor<ui32>, tensor<ui32>, tensor<ui32>, tensor<4xui8>, tensor<4xui8>
cond {
%c_17 = stablehlo.constant dense<6> : tensor<i32>
%53 = stablehlo.convert %iterArg_9 : (tensor<ui32>) -> tensor<i64>
%54 = stablehlo.convert %c_17 : (tensor<i32>) -> tensor<i64>
%55 = stablehlo.compare LT, %53, %54, NOTYPE : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %55 : tensor<i1>
} do {
%c_17 = stablehlo.constant dense<1> : tensor<ui32>
%53 = stablehlo.add %c_17, %iterArg_9 : tensor<ui32>
%c_18 = stablehlo.constant dense<0> : tensor<i32>
%54:4 = stablehlo.while(%iterArg_19 = %c_18, %iterArg_20 = %iterArg_15, %iterArg_21 = %iterArg_10, %iterArg_22 = %iterArg_11) : tensor<i32>, tensor<4xui8>, tensor<2xui32>, tensor<2xui32>
cond {
%c_23 = stablehlo.constant dense<3> : tensor<i32>
%60 = stablehlo.compare LE, %iterArg_19, %c_23, NOTYPE : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %60 : tensor<i1>
} do {
%c_23 = stablehlo.constant dense<1> : tensor<i32>
%60 = stablehlo.add %c_23, %iterArg_19 : tensor<i32>
%61 = stablehlo.add %iterArg_21, %iterArg_22 : tensor<2xui32>
%62 = stablehlo.dynamic_slice %iterArg_20, %iterArg_19, sizes = [1] : (tensor<4xui8>, tensor<i32>) -> tensor<1xui8>
%63 = stablehlo.reshape %62 : (tensor<1xui8>) -> tensor<ui8>
%64 = stablehlo.convert %63 : (tensor<ui8>) -> tensor<ui32>
%65 = stablehlo.broadcast_in_dim %64, dims = [] : (tensor<ui32>) -> tensor<2xui32>
%66 = stablehlo.shift_left %iterArg_22, %65 : tensor<2xui32>
%c_24 = stablehlo.constant dense<32> : tensor<ui32>
%67 = stablehlo.convert %63 : (tensor<ui8>) -> tensor<ui32>
%68 = stablehlo.subtract %c_24, %67 : tensor<ui32>
%69 = stablehlo.broadcast_in_dim %68, dims = [] : (tensor<ui32>) -> tensor<2xui32>
%70 = stablehlo.shift_right_logical %iterArg_22, %69 : tensor<2xui32>
%71 = stablehlo.or %66, %70 : tensor<2xui32>
%72 = stablehlo.xor %71, %61 : tensor<2xui32>
stablehlo.return %60, %iterArg_20, %61, %72 : tensor<i32>, tensor<4xui8>, tensor<2xui32>, tensor<2xui32>
}
%55 = stablehlo.broadcast_in_dim %iterArg_12, dims = [] : (tensor<ui32>) -> tensor<2xui32>
%56 = stablehlo.add %55, %54#2 : tensor<2xui32>
%57 = stablehlo.add %iterArg_13, %iterArg_9 : tensor<ui32>
%58 = stablehlo.broadcast_in_dim %57, dims = [] : (tensor<ui32>) -> tensor<2xui32>
%59 = stablehlo.add %58, %54#3 : tensor<2xui32>
stablehlo.return %53, %56, %59, %iterArg_13, %iterArg_14, %iterArg_12, %iterArg_16, %iterArg_15 : tensor<ui32>, tensor<2xui32>, tensor<2xui32>, tensor<ui32>, tensor<ui32>, tensor<ui32>, tensor<4xui8>, tensor<4xui8>
}
%29 = stablehlo.concatenate %28#1, %28#2, dim = 0 : (tensor<2xui32>, tensor<2xui32>) -> tensor<4xui32>
%30 = stablehlo.reshape %29 : (tensor<4xui32>) -> tensor<2x2xui32>
%31 = stablehlo.slice %30 [1:2, 0:1] : (tensor<2x2xui32>) -> tensor<1x1xui32>
%32 = stablehlo.reshape %31 : (tensor<1x1xui32>) -> tensor<ui32>
%33 = stablehlo.broadcast_in_dim %32, dims = [] : (tensor<ui32>) -> tensor<500000xui32>
%34 = stablehlo.add %8, %33 : tensor<500000xui32>
%35 = stablehlo.slice %6 [1:2, 0:500000] : (tensor<2x500000xui32>) -> tensor<1x500000xui32>
%36 = stablehlo.reshape %35 : (tensor<1x500000xui32>) -> tensor<500000xui32>
%37 = stablehlo.slice %30 [1:2, 1:2] : (tensor<2x2xui32>) -> tensor<1x1xui32>
%38 = stablehlo.reshape %37 : (tensor<1x1xui32>) -> tensor<ui32>
%39 = stablehlo.broadcast_in_dim %38, dims = [] : (tensor<ui32>) -> tensor<500000xui32>
%40 = stablehlo.add %36, %39 : tensor<500000xui32>
%41 = stablehlo.xor %32, %38 : tensor<ui32>
%42 = stablehlo.convert %41 : (tensor<ui32>) -> tensor<i64>
%43 = stablehlo.convert %c_4 : (tensor<i32>) -> tensor<i64>
%44 = stablehlo.xor %42, %43 : tensor<i64>
%45 = stablehlo.convert %44 : (tensor<i64>) -> tensor<ui32>
%c_7 = stablehlo.constant dense<[13, 15, 26, 6]> : tensor<4xui8>
%c_8 = stablehlo.constant dense<[17, 29, 16, 24]> : tensor<4xui8>
%46:8 = stablehlo.while(%iterArg_9 = %c_3, %iterArg_10 = %34, %iterArg_11 = %40, %iterArg_12 = %38, %iterArg_13 = %45, %iterArg_14 = %32, %iterArg_15 = %c_7, %iterArg_16 = %c_8) : tensor<ui32>, tensor<500000xui32>, tensor<500000xui32>, tensor<ui32>, tensor<ui32>, tensor<ui32>, tensor<4xui8>, tensor<4xui8>
cond {
%c_17 = stablehlo.constant dense<6> : tensor<i32>
%53 = stablehlo.convert %iterArg_9 : (tensor<ui32>) -> tensor<i64>
%54 = stablehlo.convert %c_17 : (tensor<i32>) -> tensor<i64>
%55 = stablehlo.compare LT, %53, %54, NOTYPE : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %55 : tensor<i1>
} do {
%c_17 = stablehlo.constant dense<1> : tensor<ui32>
%53 = stablehlo.add %c_17, %iterArg_9 : tensor<ui32>
%c_18 = stablehlo.constant dense<0> : tensor<i32>
%54:4 = stablehlo.while(%iterArg_19 = %c_18, %iterArg_20 = %iterArg_15, %iterArg_21 = %iterArg_10, %iterArg_22 = %iterArg_11) : tensor<i32>, tensor<4xui8>, tensor<500000xui32>, tensor<500000xui32>
cond {
%c_23 = stablehlo.constant dense<3> : tensor<i32>
%60 = stablehlo.compare LE, %iterArg_19, %c_23, NOTYPE : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %60 : tensor<i1>
} do {
%c_23 = stablehlo.constant dense<1> : tensor<i32>
%60 = stablehlo.add %c_23, %iterArg_19 : tensor<i32>
%61 = stablehlo.add %iterArg_21, %iterArg_22 : tensor<500000xui32>
%62 = stablehlo.dynamic_slice %iterArg_20, %iterArg_19, sizes = [1] : (tensor<4xui8>, tensor<i32>) -> tensor<1xui8>
%63 = stablehlo.reshape %62 : (tensor<1xui8>) -> tensor<ui8>
%64 = stablehlo.convert %63 : (tensor<ui8>) -> tensor<ui32>
%65 = stablehlo.broadcast_in_dim %64, dims = [] : (tensor<ui32>) -> tensor<500000xui32>
%66 = stablehlo.shift_left %iterArg_22, %65 : tensor<500000xui32>
%c_24 = stablehlo.constant dense<32> : tensor<ui32>
%67 = stablehlo.convert %63 : (tensor<ui8>) -> tensor<ui32>
%68 = stablehlo.subtract %c_24, %67 : tensor<ui32>
%69 = stablehlo.broadcast_in_dim %68, dims = [] : (tensor<ui32>) -> tensor<500000xui32>
%70 = stablehlo.shift_right_logical %iterArg_22, %69 : tensor<500000xui32>
%71 = stablehlo.or %66, %70 : tensor<500000xui32>
%72 = stablehlo.xor %71, %61 : tensor<500000xui32>
stablehlo.return %60, %iterArg_20, %61, %72 : tensor<i32>, tensor<4xui8>, tensor<500000xui32>, tensor<500000xui32>
}
%55 = stablehlo.broadcast_in_dim %iterArg_12, dims = [] : (tensor<ui32>) -> tensor<500000xui32>
%56 = stablehlo.add %55, %54#2 : tensor<500000xui32>
%57 = stablehlo.add %iterArg_13, %iterArg_9 : tensor<ui32>
%58 = stablehlo.broadcast_in_dim %57, dims = [] : (tensor<ui32>) -> tensor<500000xui32>
%59 = stablehlo.add %58, %54#3 : tensor<500000xui32>
stablehlo.return %53, %56, %59, %iterArg_13, %iterArg_14, %iterArg_12, %iterArg_16, %iterArg_15 : tensor<ui32>, tensor<500000xui32>, tensor<500000xui32>, tensor<ui32>, tensor<ui32>, tensor<ui32>, tensor<4xui8>, tensor<4xui8>
}
%47 = stablehlo.concatenate %46#1, %46#2, dim = 0 : (tensor<500000xui32>, tensor<500000xui32>) -> tensor<1000000xui32>
%48 = stablehlo.iota dim = 0 : tensor<1000000xi32>
%49:2 = "stablehlo.sort"(%47, %48) <{dimension = 0 : i64, is_stable = false}> ({
^bb0(%arg2: tensor<ui32>, %arg3: tensor<ui32>, %arg4: tensor<i32>, %arg5: tensor<i32>):
%53 = stablehlo.compare LT, %arg2, %arg3, NOTYPE : (tensor<ui32>, tensor<ui32>) -> tensor<i1>
stablehlo.return %53 : tensor<i1>
}) : (tensor<1000000xui32>, tensor<1000000xi32>) -> (tensor<1000000xui32>, tensor<1000000xi32>)
%50 = func.call @main_optional_34050(%iterArg_0, %49#1) : (tensor<1000000xi32>, tensor<1000000xi32>) -> tensor<1000000xi32>
%51 = stablehlo.slice %30 [0:1, 0:2] : (tensor<2x2xui32>) -> tensor<1x2xui32>
%52 = stablehlo.reshape %51 : (tensor<1x2xui32>) -> tensor<2xui32>
stablehlo.return %4, %50, %52 : tensor<i32>, tensor<1000000xi32>, tensor<2xui32>
}
%3 = "stablehlo.gather"(%arg1, %2#1) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<1000000xi32>, tensor<1000000xi32>) -> tensor<1000000xi32>
return %3, %2#2 : tensor<1000000xi32>, tensor<2xui32>
}
func.func nested @main_optional_34050(%arg0: tensor<1000000xi32>, %arg1: tensor<1000000xi32>) -> tensor<1000000xi32> {
%0 = stablehlo.reshape %arg1 : (tensor<1000000xi32>) -> tensor<1000000x1xi32>
%1 = "stablehlo.gather"(%arg0, %0) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<1000000xi32>, tensor<1000000x1xi32>) -> tensor<1000000xi32>
return %1 : tensor<1000000xi32>
}
}
I had a realization that maybe we had gotten some edge case wrong in the implementation.
Most of our implementation for shuffle is a line-by-line translation of the Jax implementation. However, Jax has a variadic sort function that sorts N tensors based on the sorting of the first one, and we don't have that, so we emulate it with argsort + take. This in turn has a very specific edge case when the input is 1D, which the linked PR should fix.
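For context, the emulation amounts to something like the sketch below (illustrative only, not the actual Nx.Random internals): compute the permutation once with argsort on the random keys, then reorder every other tensor with take.

defmodule ShuffleSketch do
  import Nx.Defn

  # Emulate a variadic sort: `keys` decides the order, `values` follows it.
  defn sort_by_keys(keys, values) do
    order = Nx.argsort(keys, axis: 0)
    {Nx.take(keys, order, axis: 0), Nx.take(values, order, axis: 0)}
  end
end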
This works on my GPU now:
iex(1)> Nx.default_backend({EXLA.Backend, client: :cuda})
{Nx.BinaryBackend, []}
iex(2)> key = Nx.Random.key(1)
#Nx.Tensor<
u32[2]
EXLA.Backend<host:0, 0.3207131337.2747400256.146062>
[0, 1]
>
iex(3)> input = Nx.iota({1_000_000})
#Nx.Tensor<
s32[1000000]
EXLA.Backend<cuda:0, 0.3207131337.2747400258.146736>
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, ...]
>
iex(4)> {output, _new_key} = Nx.Random.shuffle(key, input)
{#Nx.Tensor<
s32[1000000]
EXLA.Backend<cuda:0, 0.3207131337.2747400258.148090>
[414792, 7856, 167052, 997910, 477384, 955359, 376134, 407064, 191101, 211793, 661455, 589215, 427479, 422594, 500467, 913978, 849455, 30833, 278802, 211820, 558828, 862956, 399076, 798475, 577425, 741903, 462995, 415788, 347169, 712535, 758138, 331749, 358055, 733199, 338876, 623129, 704948, 676309, 95260, 823563, 152641, 448626, 726940, 293851, 504792, 381313, 263463, 798364, 930618, ...]
>,
#Nx.Tensor<
u32[2]
EXLA.Backend<host:0, 0.3207131337.2747400256.146137>
[2698884502, 3718367942]
>}
iex(5)> output |> Nx.sort() |> Nx.equal(input) |> Nx.all()
#Nx.Tensor<
u8
EXLA.Backend<cuda:0, 0.3207131337.2747400258.148094>
1
>
@jonatanklosko and I found the culprit (mostly Jonatan, though). We were using the wrong compare_type
attribute in the comparison functions (note the NOTYPE comparisons on ui32 operands in the module above). Don't ask me why this didn't break earlier or anywhere else, though.
We took the opportunity to port over some fixes from the metal plugin branch.
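For intuition only (an illustration of the effect, not the actual EXLA fix): if the unsigned u32 sort keys are compared as if they were signed, any key with the high bit set sorts before small keys, which is enough to corrupt the permutation the shuffle is built from.

# Hypothetical illustration of signed vs unsigned key comparison in Nx
keys = Nx.tensor([1, 2_147_483_648, 3], type: :u32)

Nx.argsort(keys)
#=> indices [0, 2, 1], the correct unsigned order

Nx.argsort(Nx.bitcast(keys, :s32))
#=> indices [1, 0, 2], because 2_147_483_648 reinterpreted as s32 is negative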
When running the following code on an EXLA backend with a CUDA GPU, through the livebook:0.14.5-cuda12 image, the result looks like this:
(code and output not preserved in this excerpt)
The bug also occurs for smaller tensors, but less frequently. The deciding factor is the size of the axis used for the shuffle; the other dimensions of the tensor do not seem to be relevant. It seems to start happening with an axis size around 100,000, and is guaranteed after 1,000,000. Moreover, about half the time the execution never completes and the Livebook runtime has to be restarted.
This has been observed on two different machines, one with an RTX4090 graphics card and one with a GTX1070ti. The bug did not occur during testing on the CPU on the same machines.
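The code from the original report is not preserved above, but presumably it was along the lines of the iex session earlier in this thread, something like:

Nx.default_backend({EXLA.Backend, client: :cuda})

key = Nx.Random.key(1)
input = Nx.iota({1_000_000})
{output, _new_key} = Nx.Random.shuffle(key, input)

# On an affected GPU this check returns 0: the shuffled tensor is not a
# permutation of the input.
output |> Nx.sort() |> Nx.equal(input) |> Nx.all()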