triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License
13.47k stars 1.66k forks source link

CUDA_ERROR_ILLEGAL_ADDRESS on certain small tile sizes #5125

Open Moerafaat opened 1 week ago

Moerafaat commented 1 week ago

The following TTGIR currently fails with CUDA_ERROR_ILLEGAL_ADDRESS.

Before #5044, this IR failed in a different way: it hit the error reported in https://github.com/triton-lang/triton/issues/3435, which is mitigated by applying the changes from #4768.

Note that the TTGIR at this stage is identical before and after #5044, so the difference arises only during (or after) lowering to LLVM.

// -----// IR Dump Before ConvertTritonGPUToLLVM (convert-triton-gpu-to-llvm) ('builtin.module' operation) //----- //
// Layout aliases used by this kernel.
//   #blocked - distributed layout of the global pointer/offset tensors and of the final store.
//   #mma     - NVIDIA MMA v2.0 accumulator layout (instrShape 16x8, warpsPerCTA [1, 2]).
//   #shared  - shared-memory layout of the A staging buffer (%29).
//   #shared1 - shared-memory layout of the B staging buffer (%30).
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 2], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 16, perPhase = 8, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}>
// Single CTA, 2 warps of 32 threads, 640 units of shared memory, target "cuda:90".
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.shared = 640 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
  // Computes one 16x16 bf16 output tile C (stored through %arg2) from A (%arg0) and B (%arg1).
  // A and B are fp8 E4M3; each K tile is upcast to f16 (tt.fp_to_fp) before tt.dot, and the
  // product is accumulated in f32 (%cst_1), then truncated to bf16 for the store.
  // Every row stride is the constant 32 (%cst_0); the K loop runs from 0 to 32 in steps of 16.
  // NOTE(review): the constants suggest a 32x32x32 problem split into 16x16 tiles, but the
  // original M/N/K are not visible in this IR -- confirm against the source kernel.
  tt.func @a_impl(%arg0: !tt.ptr<f8E4M3FN> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f8E4M3FN> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) {
    %cst = arith.constant dense<true> : tensor<16x16xi1, #blocked>
    %c1_i32 = arith.constant 1 : i32
    %cst_0 = arith.constant dense<32> : tensor<16x1xi64, #blocked>
    %c16_i64 = arith.constant 16 : i64
    %c0_i64 = arith.constant 0 : i64
    %c32_i32 = arith.constant 32 : i32
    %c0_i32 = arith.constant 0 : i32
    %c2_i32 = arith.constant 2 : i32
    %c8_i32 = arith.constant 8 : i32
    %c16_i32 = arith.constant 16 : i32
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma>
    // Program-id -> tile mapping:
    //   %1 = pid / 16, %2 = 8 * %1, %5 = min(2 - %2, 8),
    //   tile row %7 = %2 + pid % %5, tile column %9 = (pid % 16) / %5.
    // NOTE(review): this looks like the usual grouped ordering (GROUP_SIZE_M = 8, two tile rows);
    // confirm against the source kernel.
    %0 = tt.get_program_id x : i32
    %1 = arith.divsi %0, %c16_i32 : i32
    %2 = arith.muli %1, %c8_i32 : i32
    %3 = arith.subi %c2_i32, %2 : i32
    %4 = arith.cmpi slt, %3, %c8_i32 : i32
    %5 = arith.select %4, %3, %c8_i32 : i32
    %6 = arith.remsi %0, %5 : i32
    %7 = arith.addi %2, %6 : i32
    %8 = arith.remsi %0, %c16_i32 : i32
    %9 = arith.divsi %8, %5 : i32
    // Element offsets of this tile: %11 = 16 * tile row, %13 = 16 * tile column (both i64).
    %10 = arith.muli %7, %c16_i32 : i32
    %11 = arith.extsi %10 : i32 to i64
    %12 = arith.muli %9, %c16_i32 : i32
    %13 = arith.extsi %12 : i32 to i64
    // A base pointer and row term: %23 holds (%11 + i) * 32 for row i, broadcast across the
    // 16 columns. The column (K) term is added separately for each K tile.
    %14 = tt.splat %arg0 : !tt.ptr<f8E4M3FN> -> tensor<16x16x!tt.ptr<f8E4M3FN>, #blocked>
    %15 = tt.splat %11 : i64 -> tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %16 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %17 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %18 = arith.extsi %16 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %19 = arith.extsi %17 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %20 = arith.addi %15, %18 : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %21 = tt.expand_dims %20 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked>
    %22 = arith.muli %21, %cst_0 : tensor<16x1xi64, #blocked>
    %23 = tt.broadcast %22 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked>
    // B base pointer and column term: %28 holds %13 + j for column j, broadcast across the 16 rows.
    %24 = tt.splat %arg1 : !tt.ptr<f8E4M3FN> -> tensor<16x16x!tt.ptr<f8E4M3FN>, #blocked>
    %25 = tt.splat %13 : i64 -> tensor<16xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %26 = arith.addi %25, %19 : tensor<16xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %27 = tt.expand_dims %26 {axis = 0 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked>
    %28 = tt.broadcast %27 : tensor<1x16xi64, #blocked> -> tensor<16x16xi64, #blocked>
    // Single-slot (leading dimension 1) shared staging buffers for the A and B tiles, at
    // allocation offsets 0 and 256.
    %29 = triton_gpu.local_alloc  {allocation.offset = 0 : i32} : () -> !tt.memdesc<1x16x16xf8E4M3FN, #shared, #triton_gpu.shared_memory, mutable>
    %30 = triton_gpu.local_alloc  {allocation.offset = 256 : i32} : () -> !tt.memdesc<1x16x16xf8E4M3FN, #shared1, #triton_gpu.shared_memory, mutable>
    // Prologue, A: the first K tile (columns 0..15) at %23 + j, copied asynchronously into slot 0
    // with the all-true mask %cst.
    %31 = tt.expand_dims %19 {axis = 0 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked>
    %32 = tt.broadcast %31 : tensor<1x16xi64, #blocked> -> tensor<16x16xi64, #blocked>
    %33 = arith.addi %23, %32 : tensor<16x16xi64, #blocked>
    %34 = tt.addptr %14, %33 : tensor<16x16x!tt.ptr<f8E4M3FN>, #blocked>, tensor<16x16xi64, #blocked>
    %35 = triton_gpu.memdesc_subview %29[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf8E4M3FN, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf8E4M3FN, #shared, #triton_gpu.shared_memory, mutable>
    %36 = triton_gpu.async_copy_global_to_local %34, %35 mask %cst : tensor<16x16x!tt.ptr<f8E4M3FN>, #blocked> -> <16x16xf8E4M3FN, #shared, #triton_gpu.shared_memory, mutable>
    %37 = triton_gpu.async_commit_group %36
    // Prologue, B: the first K tile (rows k = 0..15) at k * 32 + %28, copied asynchronously into
    // slot 0 with the all-true mask %cst.
    %38 = tt.expand_dims %18 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked>
    %39 = arith.muli %38, %cst_0 : tensor<16x1xi64, #blocked>
    %40 = tt.broadcast %39 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked>
    %41 = arith.addi %40, %28 : tensor<16x16xi64, #blocked>
    %42 = tt.addptr %24, %41 : tensor<16x16x!tt.ptr<f8E4M3FN>, #blocked>, tensor<16x16xi64, #blocked>
    %43 = triton_gpu.memdesc_subview %30[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf8E4M3FN, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf8E4M3FN, #shared1, #triton_gpu.shared_memory, mutable>
    %44 = triton_gpu.async_copy_global_to_local %42, %43 mask %cst : tensor<16x16x!tt.ptr<f8E4M3FN>, #blocked> -> <16x16xf8E4M3FN, #shared1, #triton_gpu.shared_memory, mutable>
    %45 = triton_gpu.async_commit_group %44
    // Wait until no async-copy groups are outstanding (num = 0) before entering the loop.
    %46 = triton_gpu.async_wait %45 {num = 0 : i32}
    // Loop-carried values:
    //   %47       K index (i32, starts at 0, step 16)
    //   %48 / %49 K offset (i64) of the tile currently staged for A / B
    //   %50       f32 accumulator
    //   %51 / %52 slot indices used for the write subviews (%66) / next read subviews (%89)
    //   %53 / %54 A / B subviews read in the current iteration
    cf.br ^bb1(%c0_i32, %c0_i64, %c0_i64, %cst_1, %c0_i32, %c0_i32, %35, %43 : i32, i64, i64, tensor<16x16xf32, #mma>, i32, i32, !tt.memdesc<16x16xf8E4M3FN, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf8E4M3FN, #shared1, #triton_gpu.shared_memory, mutable>)
  ^bb1(%47: i32, %48: i64, %49: i64, %50: tensor<16x16xf32, #mma>, %51: i32, %52: i32, %53: !tt.memdesc<16x16xf8E4M3FN, #shared, #triton_gpu.shared_memory, mutable>, %54: !tt.memdesc<16x16xf8E4M3FN, #shared1, #triton_gpu.shared_memory, mutable>):  // 2 preds: ^bb0, ^bb2
    // Loop header: run the body while %47 < 32, otherwise go to the epilogue.
    %55 = arith.cmpi slt, %47, %c32_i32 : i32
    cf.cond_br %55, ^bb2, ^bb3
  ^bb2:  // pred: ^bb1
    // Prefetch predicate: the next K tile (starting at %47 + 16) exists only while %47 < 16.
    %56 = arith.cmpi slt, %47, %c16_i32 : i32
    // Load the current A/B tiles from shared memory into dot-operand layouts, upcast fp8 -> f16,
    // and accumulate into %50.
    // NOTE(review): inputPrecision = tf32 is attached to a dot with f16 operands; presumably it
    // has no effect for f16 inputs -- confirm.
    %57 = triton_gpu.local_load %53 : !tt.memdesc<16x16xf8E4M3FN, #shared, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf8E4M3FN, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
    %58 = triton_gpu.local_load %54 : !tt.memdesc<16x16xf8E4M3FN, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf8E4M3FN, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
    %59 = tt.fp_to_fp %57 : tensor<16x16xf8E4M3FN, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
    %60 = tt.fp_to_fp %58 : tensor<16x16xf8E4M3FN, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
    %61 = tt.dot %59, %60, %50, inputPrecision = tf32 {maxNumImpreciseAcc = 2147483647 : i32} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<16x16xf32, #mma>
    // Advance the A and B K offsets to the next tile.
    %62 = arith.addi %48, %c16_i64 : i64
    %63 = arith.addi %49, %c16_i64 : i64
    // Write-slot index: %66 = (%51 + 1 < 1) ? %51 + 1 : 0. There is only one slot and %51 starts
    // at 0, so this is always 0.
    %64 = arith.addi %51, %c1_i32 : i32
    %65 = arith.cmpi slt, %64, %c1_i32 : i32
    %66 = arith.select %65, %64, %c0_i32 : i32
    // A pointers for the next K tile: %23 + (%62 + j).
    %67 = tt.splat %62 : i64 -> tensor<16xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %68 = arith.addi %67, %19 : tensor<16xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %69 = tt.expand_dims %68 {axis = 0 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked>
    %70 = tt.broadcast %69 : tensor<1x16xi64, #blocked> -> tensor<16x16xi64, #blocked>
    %71 = arith.addi %23, %70 : tensor<16x16xi64, #blocked>
    %72 = tt.addptr %14, %71 : tensor<16x16x!tt.ptr<f8E4M3FN>, #blocked>, tensor<16x16xi64, #blocked>
    // NOTE(review): %73 is the same (only) slot of %29 that %53 refers to, so the copy below
    // overwrites the A data read by %57 earlier in this iteration (likewise %84 vs %54 for B).
    // Correctness depends on the lowering ordering those reads before the copies land for
    // every warp -- worth checking when chasing the illegal-address failure.
    %73 = triton_gpu.memdesc_subview %29[%66, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf8E4M3FN, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf8E4M3FN, #shared, #triton_gpu.shared_memory, mutable>
    // Both next-tile copies are masked by the prefetch predicate %56.
    %74 = tt.splat %56 : i1 -> tensor<16x16xi1, #blocked>
    %75 = triton_gpu.async_copy_global_to_local %72, %73 mask %74 : tensor<16x16x!tt.ptr<f8E4M3FN>, #blocked> -> <16x16xf8E4M3FN, #shared, #triton_gpu.shared_memory, mutable>
    %76 = triton_gpu.async_commit_group %75
    // B pointers for the next K tile: (%63 + k) * 32 + %28.
    %77 = tt.splat %63 : i64 -> tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %78 = arith.addi %77, %18 : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %79 = tt.expand_dims %78 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked>
    %80 = arith.muli %79, %cst_0 : tensor<16x1xi64, #blocked>
    %81 = tt.broadcast %80 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked>
    %82 = arith.addi %81, %28 : tensor<16x16xi64, #blocked>
    %83 = tt.addptr %24, %82 : tensor<16x16x!tt.ptr<f8E4M3FN>, #blocked>, tensor<16x16xi64, #blocked>
    %84 = triton_gpu.memdesc_subview %30[%66, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf8E4M3FN, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf8E4M3FN, #shared1, #triton_gpu.shared_memory, mutable>
    %85 = triton_gpu.async_copy_global_to_local %83, %84 mask %74 : tensor<16x16x!tt.ptr<f8E4M3FN>, #blocked> -> <16x16xf8E4M3FN, #shared1, #triton_gpu.shared_memory, mutable>
    %86 = triton_gpu.async_commit_group %85
    // Read-slot index for the next iteration, computed like %66 (so also always 0), and the A/B
    // subviews (%90, %92) that the next iteration will read.
    %87 = arith.addi %52, %c1_i32 : i32
    %88 = arith.cmpi slt, %87, %c1_i32 : i32
    %89 = arith.select %88, %87, %c0_i32 : i32
    %90 = triton_gpu.memdesc_subview %29[%89, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf8E4M3FN, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf8E4M3FN, #shared, #triton_gpu.shared_memory, mutable>
    // Wait until no async-copy groups are outstanding (num = 0) before the slot is read again.
    %91 = triton_gpu.async_wait %86 {num = 0 : i32}
    %92 = triton_gpu.memdesc_subview %30[%89, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf8E4M3FN, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf8E4M3FN, #shared1, #triton_gpu.shared_memory, mutable>
    // Step K by 16 and branch back to the loop header.
    %93 = arith.addi %47, %c16_i32 : i32
    cf.br ^bb1(%93, %62, %63, %61, %66, %89, %90, %92 : i32, i64, i64, tensor<16x16xf32, #mma>, i32, i32, !tt.memdesc<16x16xf8E4M3FN, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf8E4M3FN, #shared1, #triton_gpu.shared_memory, mutable>)
  ^bb3:  // pred: ^bb1
    // Epilogue: wait until no async-copy groups are outstanding, then free the staging buffers.
    %94 = triton_gpu.async_wait  {num = 0 : i32}
    triton_gpu.local_dealloc %29 : !tt.memdesc<1x16x16xf8E4M3FN, #shared, #triton_gpu.shared_memory, mutable>
    triton_gpu.local_dealloc %30 : !tt.memdesc<1x16x16xf8E4M3FN, #shared1, #triton_gpu.shared_memory, mutable>
    // Truncate the f32 accumulator to bf16 and store it, unmasked, to %arg2 at
    // (%11 + i) * 32 + (%13 + j).
    // NOTE(review): the #mma -> #blocked convert_layout carries allocation.offset = 0, the same
    // offset as %29; presumably it reuses that shared scratch after the deallocs -- confirm.
    %95 = arith.truncf %50 : tensor<16x16xf32, #mma> to tensor<16x16xbf16, #mma>
    %96 = tt.splat %arg2 : !tt.ptr<bf16> -> tensor<16x16x!tt.ptr<bf16>, #blocked>
    %97 = arith.addi %23, %28 : tensor<16x16xi64, #blocked>
    %98 = tt.addptr %96, %97 : tensor<16x16x!tt.ptr<bf16>, #blocked>, tensor<16x16xi64, #blocked>
    %99 = triton_gpu.convert_layout %95 {allocation.offset = 0 : i32} : tensor<16x16xbf16, #mma> -> tensor<16x16xbf16, #blocked>
    tt.store %98, %99 : tensor<16x16x!tt.ptr<bf16>, #blocked>
    tt.return
  }
}

Comparing the LLVM IR of the two versions, the failing one contains the following extra operations. Perhaps these extra operations also need to be removed, as the old logic did in #4768:

    %168 = llvm.extractvalue %165[2] : !llvm.struct<(i32, i32, i32, i32)> 
    %169 = llvm.extractvalue %165[3] : !llvm.struct<(i32, i32, i32, i32)> 
    %170 = llvm.bitcast %166 : i32 to vector<4xi8>
    %171 = llvm.extractelement %170[%23 : i32] : vector<4xi8>
    %172 = llvm.extractelement %170[%28 : i32] : vector<4xi8>
    %173 = llvm.bitcast %167 : i32 to vector<4xi8>
    %174 = llvm.extractelement %173[%23 : i32] : vector<4xi8>
    %175 = llvm.extractelement %173[%28 : i32] : vector<4xi8>
    %176 = llvm.bitcast %168 : i32 to vector<4xi8>
    %177 = llvm.extractelement %176[%23 : i32] : vector<4xi8>
    %178 = llvm.extractelement %176[%28 : i32] : vector<4xi8>
    %179 = llvm.bitcast %169 : i32 to vector<4xi8>
    %180 = llvm.extractelement %179[%23 : i32] : vector<4xi8>
    %181 = llvm.extractelement %179[%28 : i32] : vector<4xi8>
Jokeren commented 1 week ago

I can confirm that, at least with https://github.com/triton-lang/triton/pull/5121, this doesn't crash.

Jokeren commented 1 week ago

We need an end-to-end example to reproduce the illegal memory access.