elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

`Nx.Random.shuffle/3` fails for large tensors on `cuda` backend #1551

Closed (VictorT42 closed this issue 3 weeks ago)

VictorT42 commented 3 weeks ago

When running the following code on an EXLA backend with a CUDA GPU, through the livebook:0.14.5-cuda12 image:

key = Nx.Random.key(1)
input = Nx.iota({1_000_000})
{output, _new_key} = Nx.Random.shuffle(key, input)
output

the result looks like this:

#Nx.Tensor<
  s32[1000000]
  EXLA.Backend<cuda:0, 0.477854050.2664562754.2456>
  [163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, 163918, ...]
>

The bug also occurs for smaller tensors, but less frequently. The deciding factor is the size of the axis being shuffled; the other dimensions of the tensor do not seem to matter. It seems to start happening at an axis size of around 100,000 and appears to be guaranteed once the axis size reaches 1,000,000.

Moreover, about half the time the execution never completes and the Livebook runtime has to be restarted.

This has been observed on two different machines, one with an RTX 4090 graphics card and one with a GTX 1070 Ti. The bug did not occur when testing on the CPU of the same machines.
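
To detect this failure mode programmatically instead of by inspecting the printed tensor, the shuffled values can be checked for being a true permutation. The sketch below is plain Elixir working on a list (for a tensor, the values could first be extracted with `Nx.to_flat_list/1`). The module and function names are hypothetical and not part of Nx, and it assumes a non-empty list with `n >= 1`.

```elixir
defmodule PermutationCheck do
  # Hypothetical helper, not part of Nx.
  # Summarizes whether `values` is a permutation of 0..n-1 and, if not,
  # which value dominates the output. Assumes a non-empty list and n >= 1.
  def summarize(values, n) do
    freqs = Enum.frequencies(values)
    {most_common, count} = Enum.max_by(freqs, fn {_value, c} -> c end)

    %{
      permutation?: length(values) == n and Enum.sort(values) == Enum.to_list(0..(n - 1)),
      distinct: map_size(freqs),
      most_common: most_common,
      most_common_count: count
    }
  end
end

# A healthy shuffle of 0..4: every value appears exactly once.
PermutationCheck.summarize([3, 0, 4, 1, 2], 5)

# The failure mode from the report: a single value repeated everywhere.
PermutationCheck.summarize(List.duplicate(163_918, 5), 5)
```

On output like the one reported above, `distinct` would collapse to a very small number, which separates the bug from a shuffle that merely looks unusual.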

polvalente commented 3 weeks ago

Given that this works both on EXLA on CPU and on Nx.BinaryBackend, everything points towards an upstream bug in XLA.

polvalente commented 3 weeks ago

@jonatanklosko any ideas on what we should report there?

This is the MLIR module for &Nx.Random.shuffle/2:

module {
  func.func public @main(%arg0: tensor<2xui32>, %arg1: tensor<1000000xi32>) -> (tensor<1000000xi32>, tensor<2xui32>) {
    %c = stablehlo.constant dense<0> : tensor<i32>
    %0 = stablehlo.iota dim = 0 : tensor<1000000xi32>
    %1 = stablehlo.reshape %0 : (tensor<1000000xi32>) -> tensor<1000000xi32>
    %2:3 = stablehlo.while(%iterArg = %c, %iterArg_0 = %1, %iterArg_1 = %arg0) : tensor<i32>, tensor<1000000xi32>, tensor<2xui32>
     cond {
      %c_2 = stablehlo.constant dense<2> : tensor<ui32>
      %4 = stablehlo.convert %iterArg : (tensor<i32>) -> tensor<i64>
      %5 = stablehlo.convert %c_2 : (tensor<ui32>) -> tensor<i64>
      %6 = stablehlo.compare  LT, %4, %5,  NOTYPE : (tensor<i64>, tensor<i64>) -> tensor<i1>
      stablehlo.return %6 : tensor<i1>
    } do {
      %c_2 = stablehlo.constant dense<1> : tensor<i32>
      %4 = stablehlo.add %c_2, %iterArg : tensor<i32>
      %c_3 = stablehlo.constant dense<1> : tensor<ui32>
      %5 = stablehlo.iota dim = 0 : tensor<1000000xui32>
      %6 = stablehlo.reshape %5 : (tensor<1000000xui32>) -> tensor<2x500000xui32>
      %7 = stablehlo.slice %6 [0:1, 0:500000] : (tensor<2x500000xui32>) -> tensor<1x500000xui32>
      %8 = stablehlo.reshape %7 : (tensor<1x500000xui32>) -> tensor<500000xui32>
      %9 = stablehlo.iota dim = 0 : tensor<4xui32>
      %10 = stablehlo.reshape %9 : (tensor<4xui32>) -> tensor<2x2xui32>
      %11 = stablehlo.slice %10 [0:1, 0:2] : (tensor<2x2xui32>) -> tensor<1x2xui32>
      %12 = stablehlo.reshape %11 : (tensor<1x2xui32>) -> tensor<2xui32>
      %13 = stablehlo.slice %iterArg_1 [0:1] : (tensor<2xui32>) -> tensor<1xui32>
      %14 = stablehlo.reshape %13 : (tensor<1xui32>) -> tensor<ui32>
      %15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor<ui32>) -> tensor<2xui32>
      %16 = stablehlo.add %12, %15 : tensor<2xui32>
      %17 = stablehlo.slice %10 [1:2, 0:2] : (tensor<2x2xui32>) -> tensor<1x2xui32>
      %18 = stablehlo.reshape %17 : (tensor<1x2xui32>) -> tensor<2xui32>
      %19 = stablehlo.slice %iterArg_1 [1:2] : (tensor<2xui32>) -> tensor<1xui32>
      %20 = stablehlo.reshape %19 : (tensor<1xui32>) -> tensor<ui32>
      %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor<ui32>) -> tensor<2xui32>
      %22 = stablehlo.add %18, %21 : tensor<2xui32>
      %23 = stablehlo.xor %14, %20 : tensor<ui32>
      %c_4 = stablehlo.constant dense<466688986> : tensor<i32>
      %24 = stablehlo.convert %23 : (tensor<ui32>) -> tensor<i64>
      %25 = stablehlo.convert %c_4 : (tensor<i32>) -> tensor<i64>
      %26 = stablehlo.xor %24, %25 : tensor<i64>
      %27 = stablehlo.convert %26 : (tensor<i64>) -> tensor<ui32>
      %c_5 = stablehlo.constant dense<[13, 15, 26, 6]> : tensor<4xui8>
      %c_6 = stablehlo.constant dense<[17, 29, 16, 24]> : tensor<4xui8>
      %28:8 = stablehlo.while(%iterArg_9 = %c_3, %iterArg_10 = %16, %iterArg_11 = %22, %iterArg_12 = %20, %iterArg_13 = %27, %iterArg_14 = %14, %iterArg_15 = %c_5, %iterArg_16 = %c_6) : tensor<ui32>, tensor<2xui32>, tensor<2xui32>, tensor<ui32>, tensor<ui32>, tensor<ui32>, tensor<4xui8>, tensor<4xui8>
       cond {
        %c_17 = stablehlo.constant dense<6> : tensor<i32>
        %53 = stablehlo.convert %iterArg_9 : (tensor<ui32>) -> tensor<i64>
        %54 = stablehlo.convert %c_17 : (tensor<i32>) -> tensor<i64>
        %55 = stablehlo.compare  LT, %53, %54,  NOTYPE : (tensor<i64>, tensor<i64>) -> tensor<i1>
        stablehlo.return %55 : tensor<i1>
      } do {
        %c_17 = stablehlo.constant dense<1> : tensor<ui32>
        %53 = stablehlo.add %c_17, %iterArg_9 : tensor<ui32>
        %c_18 = stablehlo.constant dense<0> : tensor<i32>
        %54:4 = stablehlo.while(%iterArg_19 = %c_18, %iterArg_20 = %iterArg_15, %iterArg_21 = %iterArg_10, %iterArg_22 = %iterArg_11) : tensor<i32>, tensor<4xui8>, tensor<2xui32>, tensor<2xui32>
         cond {
          %c_23 = stablehlo.constant dense<3> : tensor<i32>
          %60 = stablehlo.compare  LE, %iterArg_19, %c_23,  NOTYPE : (tensor<i32>, tensor<i32>) -> tensor<i1>
          stablehlo.return %60 : tensor<i1>
        } do {
          %c_23 = stablehlo.constant dense<1> : tensor<i32>
          %60 = stablehlo.add %c_23, %iterArg_19 : tensor<i32>
          %61 = stablehlo.add %iterArg_21, %iterArg_22 : tensor<2xui32>
          %62 = stablehlo.dynamic_slice %iterArg_20, %iterArg_19, sizes = [1] : (tensor<4xui8>, tensor<i32>) -> tensor<1xui8>
          %63 = stablehlo.reshape %62 : (tensor<1xui8>) -> tensor<ui8>
          %64 = stablehlo.convert %63 : (tensor<ui8>) -> tensor<ui32>
          %65 = stablehlo.broadcast_in_dim %64, dims = [] : (tensor<ui32>) -> tensor<2xui32>
          %66 = stablehlo.shift_left %iterArg_22, %65 : tensor<2xui32>
          %c_24 = stablehlo.constant dense<32> : tensor<ui32>
          %67 = stablehlo.convert %63 : (tensor<ui8>) -> tensor<ui32>
          %68 = stablehlo.subtract %c_24, %67 : tensor<ui32>
          %69 = stablehlo.broadcast_in_dim %68, dims = [] : (tensor<ui32>) -> tensor<2xui32>
          %70 = stablehlo.shift_right_logical %iterArg_22, %69 : tensor<2xui32>
          %71 = stablehlo.or %66, %70 : tensor<2xui32>
          %72 = stablehlo.xor %71, %61 : tensor<2xui32>
          stablehlo.return %60, %iterArg_20, %61, %72 : tensor<i32>, tensor<4xui8>, tensor<2xui32>, tensor<2xui32>
        }
        %55 = stablehlo.broadcast_in_dim %iterArg_12, dims = [] : (tensor<ui32>) -> tensor<2xui32>
        %56 = stablehlo.add %55, %54#2 : tensor<2xui32>
        %57 = stablehlo.add %iterArg_13, %iterArg_9 : tensor<ui32>
        %58 = stablehlo.broadcast_in_dim %57, dims = [] : (tensor<ui32>) -> tensor<2xui32>
        %59 = stablehlo.add %58, %54#3 : tensor<2xui32>
        stablehlo.return %53, %56, %59, %iterArg_13, %iterArg_14, %iterArg_12, %iterArg_16, %iterArg_15 : tensor<ui32>, tensor<2xui32>, tensor<2xui32>, tensor<ui32>, tensor<ui32>, tensor<ui32>, tensor<4xui8>, tensor<4xui8>
      }
      %29 = stablehlo.concatenate %28#1, %28#2, dim = 0 : (tensor<2xui32>, tensor<2xui32>) -> tensor<4xui32>
      %30 = stablehlo.reshape %29 : (tensor<4xui32>) -> tensor<2x2xui32>
      %31 = stablehlo.slice %30 [1:2, 0:1] : (tensor<2x2xui32>) -> tensor<1x1xui32>
      %32 = stablehlo.reshape %31 : (tensor<1x1xui32>) -> tensor<ui32>
      %33 = stablehlo.broadcast_in_dim %32, dims = [] : (tensor<ui32>) -> tensor<500000xui32>
      %34 = stablehlo.add %8, %33 : tensor<500000xui32>
      %35 = stablehlo.slice %6 [1:2, 0:500000] : (tensor<2x500000xui32>) -> tensor<1x500000xui32>
      %36 = stablehlo.reshape %35 : (tensor<1x500000xui32>) -> tensor<500000xui32>
      %37 = stablehlo.slice %30 [1:2, 1:2] : (tensor<2x2xui32>) -> tensor<1x1xui32>
      %38 = stablehlo.reshape %37 : (tensor<1x1xui32>) -> tensor<ui32>
      %39 = stablehlo.broadcast_in_dim %38, dims = [] : (tensor<ui32>) -> tensor<500000xui32>
      %40 = stablehlo.add %36, %39 : tensor<500000xui32>
      %41 = stablehlo.xor %32, %38 : tensor<ui32>
      %42 = stablehlo.convert %41 : (tensor<ui32>) -> tensor<i64>
      %43 = stablehlo.convert %c_4 : (tensor<i32>) -> tensor<i64>
      %44 = stablehlo.xor %42, %43 : tensor<i64>
      %45 = stablehlo.convert %44 : (tensor<i64>) -> tensor<ui32>
      %c_7 = stablehlo.constant dense<[13, 15, 26, 6]> : tensor<4xui8>
      %c_8 = stablehlo.constant dense<[17, 29, 16, 24]> : tensor<4xui8>
      %46:8 = stablehlo.while(%iterArg_9 = %c_3, %iterArg_10 = %34, %iterArg_11 = %40, %iterArg_12 = %38, %iterArg_13 = %45, %iterArg_14 = %32, %iterArg_15 = %c_7, %iterArg_16 = %c_8) : tensor<ui32>, tensor<500000xui32>, tensor<500000xui32>, tensor<ui32>, tensor<ui32>, tensor<ui32>, tensor<4xui8>, tensor<4xui8>
       cond {
        %c_17 = stablehlo.constant dense<6> : tensor<i32>
        %53 = stablehlo.convert %iterArg_9 : (tensor<ui32>) -> tensor<i64>
        %54 = stablehlo.convert %c_17 : (tensor<i32>) -> tensor<i64>
        %55 = stablehlo.compare  LT, %53, %54,  NOTYPE : (tensor<i64>, tensor<i64>) -> tensor<i1>
        stablehlo.return %55 : tensor<i1>
      } do {
        %c_17 = stablehlo.constant dense<1> : tensor<ui32>
        %53 = stablehlo.add %c_17, %iterArg_9 : tensor<ui32>
        %c_18 = stablehlo.constant dense<0> : tensor<i32>
        %54:4 = stablehlo.while(%iterArg_19 = %c_18, %iterArg_20 = %iterArg_15, %iterArg_21 = %iterArg_10, %iterArg_22 = %iterArg_11) : tensor<i32>, tensor<4xui8>, tensor<500000xui32>, tensor<500000xui32>
         cond {
          %c_23 = stablehlo.constant dense<3> : tensor<i32>
          %60 = stablehlo.compare  LE, %iterArg_19, %c_23,  NOTYPE : (tensor<i32>, tensor<i32>) -> tensor<i1>
          stablehlo.return %60 : tensor<i1>
        } do {
          %c_23 = stablehlo.constant dense<1> : tensor<i32>
          %60 = stablehlo.add %c_23, %iterArg_19 : tensor<i32>
          %61 = stablehlo.add %iterArg_21, %iterArg_22 : tensor<500000xui32>
          %62 = stablehlo.dynamic_slice %iterArg_20, %iterArg_19, sizes = [1] : (tensor<4xui8>, tensor<i32>) -> tensor<1xui8>
          %63 = stablehlo.reshape %62 : (tensor<1xui8>) -> tensor<ui8>
          %64 = stablehlo.convert %63 : (tensor<ui8>) -> tensor<ui32>
          %65 = stablehlo.broadcast_in_dim %64, dims = [] : (tensor<ui32>) -> tensor<500000xui32>
          %66 = stablehlo.shift_left %iterArg_22, %65 : tensor<500000xui32>
          %c_24 = stablehlo.constant dense<32> : tensor<ui32>
          %67 = stablehlo.convert %63 : (tensor<ui8>) -> tensor<ui32>
          %68 = stablehlo.subtract %c_24, %67 : tensor<ui32>
          %69 = stablehlo.broadcast_in_dim %68, dims = [] : (tensor<ui32>) -> tensor<500000xui32>
          %70 = stablehlo.shift_right_logical %iterArg_22, %69 : tensor<500000xui32>
          %71 = stablehlo.or %66, %70 : tensor<500000xui32>
          %72 = stablehlo.xor %71, %61 : tensor<500000xui32>
          stablehlo.return %60, %iterArg_20, %61, %72 : tensor<i32>, tensor<4xui8>, tensor<500000xui32>, tensor<500000xui32>
        }
        %55 = stablehlo.broadcast_in_dim %iterArg_12, dims = [] : (tensor<ui32>) -> tensor<500000xui32>
        %56 = stablehlo.add %55, %54#2 : tensor<500000xui32>
        %57 = stablehlo.add %iterArg_13, %iterArg_9 : tensor<ui32>
        %58 = stablehlo.broadcast_in_dim %57, dims = [] : (tensor<ui32>) -> tensor<500000xui32>
        %59 = stablehlo.add %58, %54#3 : tensor<500000xui32>
        stablehlo.return %53, %56, %59, %iterArg_13, %iterArg_14, %iterArg_12, %iterArg_16, %iterArg_15 : tensor<ui32>, tensor<500000xui32>, tensor<500000xui32>, tensor<ui32>, tensor<ui32>, tensor<ui32>, tensor<4xui8>, tensor<4xui8>
      }
      %47 = stablehlo.concatenate %46#1, %46#2, dim = 0 : (tensor<500000xui32>, tensor<500000xui32>) -> tensor<1000000xui32>
      %48 = stablehlo.iota dim = 0 : tensor<1000000xi32>
      %49:2 = "stablehlo.sort"(%47, %48) <{dimension = 0 : i64, is_stable = false}> ({
      ^bb0(%arg2: tensor<ui32>, %arg3: tensor<ui32>, %arg4: tensor<i32>, %arg5: tensor<i32>):
        %53 = stablehlo.compare  LT, %arg2, %arg3,  NOTYPE : (tensor<ui32>, tensor<ui32>) -> tensor<i1>
        stablehlo.return %53 : tensor<i1>
      }) : (tensor<1000000xui32>, tensor<1000000xi32>) -> (tensor<1000000xui32>, tensor<1000000xi32>)
      %50 = func.call @main_optional_34050(%iterArg_0, %49#1) : (tensor<1000000xi32>, tensor<1000000xi32>) -> tensor<1000000xi32>
      %51 = stablehlo.slice %30 [0:1, 0:2] : (tensor<2x2xui32>) -> tensor<1x2xui32>
      %52 = stablehlo.reshape %51 : (tensor<1x2xui32>) -> tensor<2xui32>
      stablehlo.return %4, %50, %52 : tensor<i32>, tensor<1000000xi32>, tensor<2xui32>
    }
    %3 = "stablehlo.gather"(%arg1, %2#1) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<1000000xi32>, tensor<1000000xi32>) -> tensor<1000000xi32>
    return %3, %2#2 : tensor<1000000xi32>, tensor<2xui32>
  }
  func.func nested @main_optional_34050(%arg0: tensor<1000000xi32>, %arg1: tensor<1000000xi32>) -> tensor<1000000xi32> {
    %0 = stablehlo.reshape %arg1 : (tensor<1000000xi32>) -> tensor<1000000x1xi32>
    %1 = "stablehlo.gather"(%arg0, %0) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<1000000xi32>, tensor<1000000x1xi32>) -> tensor<1000000xi32>
    return %1 : tensor<1000000xi32>
  }
}
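
For readers skimming the module above: the innermost `stablehlo.while` loops (the ones using `shift_left`, `shift_right_logical`, `or` and `xor`) look like the mixing rounds of a Threefry-style generator, and the rotation tables `[13, 15, 26, 6]` and `[17, 29, 16, 24]` match the Threefry-2x32 constants. The sketch below restates one such loop in plain Elixir on scalar 32-bit words. It is only a reading aid under that interpretation, the module name is hypothetical, and it omits the key-schedule additions performed by the enclosing loop.

```elixir
defmodule ThreefryRoundSketch do
  import Bitwise

  # Elixir integers are unbounded, so results are masked to 32 bits by hand.
  @mask 0xFFFFFFFF

  # Rotate a 32-bit word left by r bits (0 < r < 32):
  # shift_left OR shift_right_logical by (32 - r), as in the MLIR.
  def rotl32(x, r), do: ((x <<< r) ||| (x >>> (32 - r))) &&& @mask

  # One mixing step, mirroring the body of the innermost loop:
  #   sum = x0 + x1            (wrapping 32-bit add)
  #   x1' = rotl(x1, r) xor sum
  def mix({x0, x1}, r) do
    sum = (x0 + x1) &&& @mask
    {sum, bxor(rotl32(x1, r), sum)}
  end

  # Apply one mixing step per rotation amount, in order.
  def rounds(state, rotations) do
    Enum.reduce(rotations, state, fn r, acc -> mix(acc, r) end)
  end
end

# Four mixing steps with the first rotation table from the module above.
ThreefryRoundSketch.rounds({0, 1}, [13, 15, 26, 6])
```

The output of these rounds is what feeds `stablehlo.sort` as the random sort keys, so the keys are unsigned 32-bit values (`ui32`).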

polvalente commented 3 weeks ago

I realized that we might have gotten an edge case wrong in the implementation. Most of our shuffle implementation is a line-by-line translation of the JAX implementation.

However, JAX has a variadic sort function that sorts N tensors according to the ordering of the first one. We don't have that, so we emulate it with argsort + take. This in turn has a very specific edge case when the input is 1D, which the linked PR should fix.
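
To make the "argsort + take" emulation concrete, here is a minimal sketch in plain Elixir using lists instead of tensors. It is not the Nx implementation: the module name is hypothetical, and Erlang's `:rand` stands in for the key-based generator that `Nx.Random` uses.

```elixir
defmodule SortShuffleSketch do
  # Indices that would sort `keys` in ascending order (an "argsort").
  def argsort(keys) do
    keys
    |> Enum.with_index()
    |> Enum.sort_by(fn {key, _index} -> key end)
    |> Enum.map(fn {_key, index} -> index end)
  end

  # Pick the elements of `values` at the given positions (a "take").
  def take(values, indices) do
    tuple = List.to_tuple(values)
    Enum.map(indices, fn index -> elem(tuple, index) end)
  end

  # Shuffle by drawing one random 32-bit key per element and reordering
  # the values by the order of those keys. The seed makes it repeatable.
  def shuffle(values, seed) do
    state = :rand.seed_s(:exsss, {seed, seed + 1, seed + 2})

    {keys, _state} =
      Enum.map_reduce(values, state, fn _value, st ->
        :rand.uniform_s(0xFFFFFFFF, st)
      end)

    take(values, argsort(keys))
  end
end

# Same multiset of values, different order; same seed gives the same order.
SortShuffleSketch.shuffle(Enum.to_list(0..9), 1)
```

A variadic sort would reorder `values` alongside `keys` in a single pass. Splitting it into argsort and take gives the same result only if the comparison used by the sort and the indexing used by the take both behave as expected, which is why bugs in either step show up as a broken shuffle.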

This works on my GPU now:

iex(1)> Nx.default_backend({EXLA.Backend, client: :cuda})
{Nx.BinaryBackend, []}
iex(2)> key = Nx.Random.key(1)
#Nx.Tensor<
  u32[2]
  EXLA.Backend<host:0, 0.3207131337.2747400256.146062>
  [0, 1]
>
iex(3)> input = Nx.iota({1_000_000})
#Nx.Tensor<
  s32[1000000]
  EXLA.Backend<cuda:0, 0.3207131337.2747400258.146736>
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, ...]
>
iex(4)> {output, _new_key} = Nx.Random.shuffle(key, input)
{#Nx.Tensor<
   s32[1000000]
   EXLA.Backend<cuda:0, 0.3207131337.2747400258.148090>
   [414792, 7856, 167052, 997910, 477384, 955359, 376134, 407064, 191101, 211793, 661455, 589215, 427479, 422594, 500467, 913978, 849455, 30833, 278802, 211820, 558828, 862956, 399076, 798475, 577425, 741903, 462995, 415788, 347169, 712535, 758138, 331749, 358055, 733199, 338876, 623129, 704948, 676309, 95260, 823563, 152641, 448626, 726940, 293851, 504792, 381313, 263463, 798364, 930618, ...]
 >,
 #Nx.Tensor<
   u32[2]
   EXLA.Backend<host:0, 0.3207131337.2747400256.146137>
   [2698884502, 3718367942]
 >}
iex(5)> output |> Nx.sort() |> Nx.equal(input) |> Nx.all()
#Nx.Tensor<
  u8
  EXLA.Backend<cuda:0, 0.3207131337.2747400258.148094>
  1
>

polvalente commented 3 weeks ago

@jonatanklosko and I found the culprit (mostly Jonatan, though). We were using the wrong compare_type attribute in the comparison functions. Don't ask me why this didn't break earlier, or anywhere else, though.
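
As a rough illustration of why the comparison type matters for the `ui32` sort keys, the sketch below sorts the same 32-bit values once as unsigned integers and once with their bits reinterpreted as signed. This is an assumption-laden example in plain Elixir, not a reconstruction of what XLA did on CUDA: the thread does not say which compare types were involved, and the module name is hypothetical.

```elixir
defmodule CompareTypeSketch do
  # Reinterpret the bits of an unsigned 32-bit integer as a signed one.
  def as_signed32(u) when u in 0..0xFFFFFFFF do
    <<s::signed-integer-size(32)>> = <<u::unsigned-integer-size(32)>>
    s
  end

  # Ordering when the values are compared as unsigned integers.
  def sort_unsigned(keys), do: Enum.sort(keys)

  # Ordering when the same bit patterns are compared as signed integers.
  def sort_as_signed(keys), do: Enum.sort_by(keys, &as_signed32/1)
end

keys = [1, 2_698_884_502, 3_718_367_942, 7]

CompareTypeSketch.sort_unsigned(keys)
# => [1, 7, 2698884502, 3718367942]

CompareTypeSketch.sort_as_signed(keys)
# => [2698884502, 3718367942, 1, 7]
```

Any key with the top bit set moves to the front under the signed interpretation, so a comparator with the wrong type produces a different order from the intended one.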

We took the opportunity to port over some fixes from the metal plugin branch.