triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Custom tensor roll kernel produces wrong results for variations of block size and tl.constexpr #2941

Open Marks101 opened 8 months ago

Marks101 commented 8 months ago

Dear triton team,

I am currently debugging an issue with a kernel that is supposed to replace torch.roll followed by zeroing out the first row of a 2D matrix. This is the code I have:

import torch

import triton
import triton.language as tl

@triton.jit
def triton_roll_and_zero_first_row_kernel(in_ptr, out_ptr, NUM_ROWS: tl.constexpr, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    row_idx = xindex // 2
    col_idx = xindex % 2
    rolled_row = (row_idx + 1) % NUM_ROWS
    rolled_xindex = 2 * rolled_row + col_idx
    result = tl.load(in_ptr + rolled_xindex)
    result = tl.where(row_idx == 0, 0.0, result)
    tl.store(out_ptr + xindex, result)

def triton_roll_and_zero_first_row(x):
    assert x.size(1) == 2
    y = torch.empty_like(x)
    # Size the grid from the function argument rather than the global `inp`.
    grid = lambda meta: (triton.cdiv(x.numel(), meta['XBLOCK']), )
    triton_roll_and_zero_first_row_kernel[grid](x, y, NUM_ROWS=x.size(0), XBLOCK=256)
    return y

def roll_and_zero_first_row(x):
    x = torch.roll(x, -1, 0)
    x[0].fill_(0.0)
    return x

if __name__ == "__main__":
    inp = torch.rand(1024, 2, device="cuda")

    out_eager = roll_and_zero_first_row(inp)
    out_triton = triton_roll_and_zero_first_row(inp)

    print("eager", out_eager)
    print("triton", out_triton)
    assert torch.all(out_eager == out_triton)
    print("PASSED")

In this form the assert does not pass on my system. It passes if I either set XBLOCK=128 or if I remove tl.constexpr from NUM_ROWS. This strikes me as odd. Could you please help me to understand this behaviour?
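As a sanity check that the index arithmetic itself does not depend on the block size, the kernel can be replayed on the CPU. This is a minimal sketch in plain Python, not Triton code: `reference_roll_and_zero_first_row` and `emulate_kernel` are hypothetical helper names, and the emulation assumes (like the kernel above, which has no mask) that the number of elements is a multiple of XBLOCK.

```python
def reference_roll_and_zero_first_row(flat, num_rows):
    """Eager semantics on a flattened (num_rows, 2) matrix: roll rows up by one, zero row 0."""
    rows = [flat[2 * r: 2 * r + 2] for r in range(num_rows)]
    rolled = rows[1:] + rows[:1]          # same as torch.roll(x, -1, 0)
    rolled[0] = [0.0, 0.0]                # same as x[0].fill_(0.0)
    return [value for row in rolled for value in row]


def emulate_kernel(flat, num_rows, xblock):
    """Replay the kernel's index arithmetic, one program (block) at a time."""
    assert len(flat) % xblock == 0, "the kernel has no mask, so sizes must divide evenly"
    out = [None] * len(flat)
    for pid in range(len(flat) // xblock):
        for lane in range(xblock):
            xindex = pid * xblock + lane
            row_idx = xindex // 2
            col_idx = xindex % 2
            rolled_row = (row_idx + 1) % num_rows
            value = flat[2 * rolled_row + col_idx]
            out[xindex] = 0.0 if row_idx == 0 else value
    return out


data = [float(i) for i in range(1024 * 2)]
expected = reference_roll_and_zero_first_row(data, 1024)
results = {xblock: emulate_kernel(data, 1024, xblock) == expected
           for xblock in (128, 256)}
print(results)
```

Both block sizes agree with the eager reference here, which suggests that the kernel source is fine and that the discrepancy comes from somewhere in the compilation pipeline.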

The blueprint for this kernel was actually produced applying torch.compile to roll_and_zero_first_row and it is affected by this issue as well. I just renamed a few variables and reordered code to make it human-readable. If you can confirm this behavior I would open an issue at pytorch as well.

I am currently on triton-nightly==2.1.0.post20240108192258, torch==2.1.2 and CUDA 12.1

bdhirsh commented 8 months ago

fwiw - I tried stepping through the kernel with interpreter mode, and noticed that the repro above passes when I use TRITON_INTERPRET=1 (but fails without interpreter mode)

manman-ren commented 8 months ago

I am able to repro this on main branch. Comparing between XBLOCK=128 and XBLOCK=256, ttir/ttgir look reasonable, the only difference in ttgir is sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4] vs sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4]

The differences in llir also look reasonable. One difference in ptx is that for XBLOCK=128, there is extra logic on ctaid.x: bfe.u32 %r8, %r1, 24, 1;

It is still not clear to me why XBLOCK=256 doesn't work.

Attaching the ptx files. b256-ptx.txt b128-ptx.txt

manman-ren commented 8 months ago

If we "export DISABLE_LLVM_OPT=1", the test case will pass. So it is related to llvm optimizations.
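To compare these settings side by side, the repro can be launched once per environment configuration. The sketch below only uses the two environment variables mentioned in this thread (TRITON_INTERPRET and DISABLE_LLVM_OPT); `CONFIGS`, `build_env`, `run_repro` and the script name `repro.py` are hypothetical names chosen for illustration. It may be necessary to clear the Triton cache between runs so that a previously compiled kernel is not reused.

```python
import os
import subprocess
import sys

# Environment overrides taken from this thread; the dictionary keys are just labels.
CONFIGS = {
    "default": {},
    "interpreter": {"TRITON_INTERPRET": "1"},
    "no_llvm_opt": {"DISABLE_LLVM_OPT": "1"},
}


def build_env(overrides, base=None):
    """Return a copy of the base environment with the given overrides applied."""
    env = dict(os.environ if base is None else base)
    env.update(overrides)
    return env


def run_repro(script, configs=CONFIGS):
    """Run the script once per configuration; True means it exited with status 0."""
    results = {}
    for name, overrides in configs.items():
        proc = subprocess.run(
            [sys.executable, script],
            env=build_env(overrides),
            capture_output=True,
            text=True,
        )
        results[name] = proc.returncode == 0
    return results
```

Calling `run_repro("repro.py")`, with the script from the first comment saved under that assumed name, returns one pass/fail entry per configuration, since the failing assert makes the script exit with a non-zero status.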

manman-ren commented 7 months ago

If we disable InstCombine in optimize_module, it will pass. But I haven't figured out why yet. It is not obvious to me from looking at the differences in ptx (disable-opt against with-opt).

Attaching the files. disable-opt-llir.txt disable-opt-ptx.txt with-opt-llir.txt with-opt-ptx.txt

@@ -28,8 +26,8 @@
 $L__tmp0:
    .loc    1 9 36
    mov.u32     %r6, %tid.x;
-   and.b32     %r7, %r6, 127;
-   shl.b32     %r8, %r7, 1;
+   shl.b32     %r7, %r6, 1;
+   and.b32     %r8, %r7, 254;
    .loc    1 8 28
    // begin inline asm
    mov.u32 %r1, %ctaid.x;
@@ -60,19 +58,14 @@
    mov.u32 %r3, 0x0;
    @%p1 ld.global.v2.b32 { %r2, %r3 }, [ %rd1 + 0 ];
    // end inline asm
-   mov.b32     %f1, %r2;
-   mov.b32     %f2, %r3;
    .loc    1 15 33
-   setp.eq.s32     %p3, %r11, 0;
-   .loc    1 15 41
-   selp.f32    %f3, 0f00000000, %f1, %p3;
-   selp.f32    %f4, 0f00000000, %f2, %p3;
+   setp.eq.s32     %p3, %r10, 0;
    .loc    1 16 23
    mul.wide.s32    %rd6, %r10, 4;
    add.s64     %rd2, %rd4, %rd6;
    .loc    1 16 31
-   mov.b32     %r4, %f3;
-   mov.b32     %r5, %f4;
+   selp.b32    %r4, 0, %r2, %p3;
+   selp.b32    %r5, 0, %r3, %p3;
    // begin inline asm
    @%p1 st.global.v2.b32 [ %rd2 + 0 ], { %r4, %r5 };
    // end inline asm

htyu commented 7 months ago

It looks like the predication is different in the two cases.

Disable opt:

    mov.u32 %r1, %ctaid.x;
    shl.b32     %r9, %r1, 8;
    or.b32      %r10, %r9, %r8;
    shr.s32     %r11, %r10, 1;
    setp.eq.s32     %p3, %r11, 0;

With opt:

    mov.u32 %r1, %ctaid.x;
    shl.b32     %r9, %r1, 8;
    or.b32      %r10, %r9, %r8;
    setp.eq.s32     %p3, %r10, 0;

manman-ren commented 7 months ago

Yes, but the problem is that r10 is supposed to be a multiple of 2, so checking it against 0 vs. checking (r10 >> 1) against 0 should be the same?

htyu commented 7 months ago

Yes, but the problem is that r10 is supposed to be a multiple of 2, so checking it against 0 vs. checking (r10 >> 1) against 0 should be the same?

I see, it could make a difference if r10 has value of 1?

jlebar commented 7 months ago

If the ptx looks correct in both cases, then this sounds like a ptxas bug. (Have we tried the latest ptxas?)

manman-ren commented 7 months ago

If the ptx looks correct in both cases, then this sounds like a ptxas bug. (Have we tried the latest ptxas?)

That is possible. I can try it and see if it makes a difference. Mine is currently cuda-12.1.

manman-ren commented 7 months ago

Yes, but the problem is that r10 is supposed to be a multiple of 2, so checking it against 0 vs. checking (r10 >> 1) against 0 should be the same?

I see, it could make a difference if r10 has value of 1?

Yes, if r10 can be 1. This is how r10 is computed:

    mov.u32     %r6, %tid.x;
    shl.b32     %r7, %r6, 1;
    and.b32     %r8, %r7, 254;
    mov.u32     %r1, %ctaid.x;
    shl.b32     %r9, %r1, 8;
    or.b32      %r10, %r9, %r8;

%r8 should be a multiple of 2 and %r9 should be a multiple of 2, so %r10 should always be even.
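The parity argument can be checked exhaustively for this launch. The sketch below is a plain-Python emulation of the PTX lines above, assuming 128 threads per block (4 warps) and 8 blocks (2048 elements with XBLOCK=256); `r10_from_ptx` and `to_signed32` are hypothetical helper names.

```python
MASK32 = 0xFFFFFFFF


def r10_from_ptx(tid_x, ctaid_x):
    """Mirror the optimized PTX: r10 = (ctaid.x << 8) | ((tid.x << 1) & 254)."""
    r7 = (tid_x << 1) & MASK32
    r8 = r7 & 254
    r9 = (ctaid_x << 8) & MASK32
    return r9 | r8


def to_signed32(value):
    """Reinterpret a 32-bit pattern as a signed integer, as shr.s32 would see it."""
    return value - (1 << 32) if value & 0x80000000 else value


mismatches = []
for ctaid in range(8):            # 2048 elements / XBLOCK of 256
    for tid in range(128):        # 4 warps of 32 threads
        r10 = r10_from_ptx(tid, ctaid)
        assert r10 % 2 == 0       # r10 is always even
        unoptimized = (to_signed32(r10) >> 1) == 0   # shr.s32 + setp.eq.s32 on r11
        optimized = r10 == 0                         # setp.eq.s32 directly on r10
        if unoptimized != optimized:
            mismatches.append((ctaid, tid))

print(len(mismatches))
```

No thread sees a different predicate, so under these assumptions the two PTX variants are equivalent and the InstCombine rewrite by itself looks valid.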

ThomasRaoux commented 7 months ago

If the ptx looks correct in both cases, then this sounds like a ptxas bug. (Have we tried the latest ptxas?)

That is possible. I can try it and see if it makes a difference. Mine is currently cuda-12.1.

you can also disable optimizations in ptxas --opt-level 0

manman-ren commented 7 months ago

Disabling optimization for ptxas also fixed the problem. def-sass.txt no-ptxas-opt-sass.txt

Patch to enable debugging: https://github.com/openai/triton/pull/2995

htyu commented 7 months ago

Disabling optimization for ptxas also fixed the problem. def-sass.txt no-ptxas-opt-sass.txt

Patch to enable debugging: #2995

Thanks for giving it a shot. So the problem went away with LLVM opt on but PTX opt off?

manman-ren commented 7 months ago

A quick summary:

  1. Default config (with llvm optimizations, with ptxas optimizations): BLOCKSIZE of 128 works; BLOCKSIZE of 256 fails.
  2. Disable llvm optimizations via DISABLE_LLVM_OPT, with ptxas optimizations: BLOCKSIZE of 256 works.
  2a. With llvm optimizations O0, disable InstCombinePass here https://github.com/openai/triton/blob/e6e5d5468e92ed3af3e40babdd55c3da506ab01f/python/src/llvm.cc#L190: BLOCKSIZE of 256 works.
  3. Enable llvm optimizations without ptxas optimizations: BLOCKSIZE of 256 works.

I haven't looked at the differences in sass for item 3 yet.

manman-ren commented 7 months ago

Attempted to debug this with cuda-gdb, but it turned out that the issue is likely in ptxas: the instruction that sets the predicate seems to be gone in the sass (i.e. P0 is used but never set).

          IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28]
          S2R R0, SR_TID.X
          IMAD.MOV.U32 R6, RZ, RZ, 0x4
          ULDC.64 UR4, c[0x0][0x118]
          S2R R5, SR_CTAID.X
          IMAD.SHL.U32 R0, R0, 0x2, RZ
          LOP3.LUT R0, R0, 0xfe, RZ, 0xc0, !PT
          PRMT R4, R0, 0x6540, R5
          LEA.HI.SX32 R2, R4, 0x1, 0x1f
          SHF.R.S32.HI R3, RZ, 0x1f, R2
          LEA.HI R3, R3, R2, RZ, 0xa
          LOP3.LUT R3, R3, 0x7ffffc00, RZ, 0xc0, !PT
          IMAD.IADD R3, R2, 0x1, -R3
          IMAD.SHL.U32 R3, R3, 0x2, RZ
          IMAD.WIDE R2, R3, R6, c[0x0][0x160]
          LDG.E.64 R2, [R2.64]
          PRMT RZ, R0, 0x6540, R5
          IMAD.WIDE R4, R4, R6, c[0x0][0x168]
          SEL R6, R2, RZ, P0
          SEL R7, R3, RZ, P0
          STG.E.64 [R4.64], R6

Setting of predicate in ptx:

    setp.eq.s32     %p3, %r10, 0;
    .loc    1 16 23
    mul.wide.s32    %rd6, %r10, 4;
    add.s64     %rd2, %rd4, %rd6;
    .loc    1 16 31
    selp.b32    %r4, 0, %r2, %p3;
    selp.b32    %r5, 0, %r3, %p3;

ThomasRaoux commented 7 months ago

interesting findings! Ptxas bugs are always very hard to track :(

Have you tried using the latest ptxas as suggested by @jlebar? (we haven't upgraded to the latest one yet)

You can also run ptxas with different versions and override the cubin generated if it is simpler

manman-ren commented 7 months ago

You can also run ptxas with different versions and override the cubin generated if it is simpler

This will be super useful if we can swap the cubin and run it. How do we do it?

Have you tried using the latest ptxas as suggested by @jlebar? (we haven't upgraded to the latest one yet)

I haven't. Let me ask around how to get the latest.

ThomasRaoux commented 7 months ago

You can also run ptxas with different versions and override the cubin generated if it is simpler

This will be super useful if we can swap the cubin and run it. How do we do it?

It's a bit of an experimental debug feature, but it is here: https://github.com/openai/triton/blob/main/python/triton/compiler/compiler.py#L218

The flow I usually use:

  1. export TRITON_KERNEL_DUMP=1 and clear the cache
  2. Run the test: python ....
  3. Find the kernel you are interested in under ~/.triton/dump/kernel_hash/
  4. Copy the new cubin into ~/.triton/override/ using exactly the same path as above: ~/.triton/override/kernel_hash/kernel_name.cubin
  5. Clear the cache and run TRITON_KERNEL_OVERRIDE=1 python....
  6. Check the console for the message "Overriding kernel with file kernel_name.cubin" to make sure it happened
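Step 4 of the flow above can be scripted. The following is a rough sketch that only mirrors the directory layout described in the list (~/.triton/dump and ~/.triton/override); `stage_override_cubin` and its parameters are hypothetical names, not part of Triton, and the layout may differ between Triton versions.

```python
import shutil
from pathlib import Path


def stage_override_cubin(new_cubin, kernel_hash, kernel_name, triton_home=None):
    """Copy a separately built cubin to the override path that mirrors the dump path.

    Assumes the layout from the list above:
      <triton_home>/dump/<kernel_hash>/...
      <triton_home>/override/<kernel_hash>/<kernel_name>.cubin
    """
    home = Path(triton_home) if triton_home else Path.home() / ".triton"
    dumped = home / "dump" / kernel_hash
    if not dumped.is_dir():
        raise FileNotFoundError(
            f"no dump found at {dumped}; run once with TRITON_KERNEL_DUMP=1 first"
        )
    target = home / "override" / kernel_hash / f"{kernel_name}.cubin"
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(new_cubin, target)
    return target
```

After staging the file, the test is rerun with TRITON_KERNEL_OVERRIDE=1 and a cleared cache, as in steps 5 and 6.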

Have you tried using the latest ptxas as suggested by @jlebar? (we haven't upgraded to the latest one yet)

I haven't. Let me ask around how to get the latest.

You can see an example here: https://github.com/ThomasRaoux/triton/commit/edb74eb5bda10c7019f3717acb5d0c97eb2a411d

manman-ren commented 7 months ago

Tried with the cuda_12.3.r12.3 ptxas; it generates the same code. I am not familiar with SASS, so maybe it is possible that P0 is set implicitly by some instruction? But the incorrect outputs are all 0.0, which looks like the predicate is always choosing 0.0.

Jokeren commented 7 months ago

If it's very likely that there's a ptxas bug, we could prepare a reproducer and send it to the nvidia compiler folks. They usually respond quickly in my experience.

manman-ren commented 7 months ago

@Jokeren Yes here is the repro and a README, let me know how to send it over to NVidia compiler folks. repro-ptx.txt README.txt

manman-ren commented 7 months ago

Filed as https://developer.nvidia.com/bugs/4474599

manman-ren commented 7 months ago

Reply from NVidia: This is a ptxas optimization bug that we've already fixed internally. It will be available in an update release of CUDA 12.4 soon. The optimization was incorrectly transforming a LOP3 (which supports predicate output) into a PRMT (which doesn't support predicate out).
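To see why replacing the OR with a byte permute preserves the value, here is a small emulation of the PTX prmt.b32 default mode applied to the SASS line `PRMT R4, R0, 0x6540, R5` quoted earlier. Reading that line as a = R0, selector = 0x6540, b = R5 is an assumption, since SASS operand conventions are not publicly documented; `prmt` is a hypothetical helper written for this sketch.

```python
def prmt(a, b, selector):
    """Emulate PTX prmt.b32 (default mode): pick 4 bytes from the 8-byte pool {b, a}."""
    pool = [(a >> (8 * i)) & 0xFF for i in range(4)]     # bytes 0..3 come from a
    pool += [(b >> (8 * i)) & 0xFF for i in range(4)]    # bytes 4..7 come from b
    result = 0
    for i in range(4):
        nibble = (selector >> (4 * i)) & 0xF
        byte = pool[nibble & 0x7]
        if nibble & 0x8:                                  # msb set: replicate the sign bit
            byte = 0xFF if byte & 0x80 else 0x00
        result |= byte << (8 * i)
    return result


# Selector 0x6540 keeps byte 0 of a and places the low three bytes of b above it,
# i.e. (b << 8) | (a & 0xFF) for b below 2**24.
all_equal = all(
    prmt((tid << 1) & 0xFE, ctaid, 0x6540) == ((ctaid << 8) | ((tid << 1) & 0xFE))
    for ctaid in range(8)
    for tid in range(128)
)
print(all_equal)
```

Under that reading, the permute computes the same value as the original `or.b32`. What it cannot do, according to the reply above, is write a predicate, which would be consistent with `PRMT RZ, R0, 0x6540, R5` discarding its result in the SASS while P0 is never set.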

Marks101 commented 7 months ago

Thanks to everyone involved for figuring this out 🙌

jpilaul commented 6 months ago

CUDA 12.4 is out. Has anyone tested it?

windsornguyen commented 2 months ago

Still running into the same issues with CUDA 12.4, torch=2.5.0.dev20240709+cu124, triton=2.3.1