triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Print statements inside kernel print incorrect value of int64 tensors #4060

Open georg-wolflein opened 4 months ago

georg-wolflein commented 4 months ago

I came across a bug using int64 tensors. Here's a minimal reproduction.

MWE:

import torch
import triton
import triton.language as tl

@triton.jit
def ndscore_kernel(ptr):
    value = tl.load(ptr)
    print("value in kernel", value)
    tl.store(ptr, value + 1)

ptr = torch.tensor(42, dtype=torch.int64).cuda()
print("value before kernel", ptr.item())
ndscore_kernel[(1,)](ptr)
print("value after kernel", ptr.item())

Output:

value before kernel 42
pid (0, 0, 0) idx () value in kernel: 0
[...]
pid (0, 0, 0) idx () value in kernel: 0
value after kernel 43

Why does the kernel print 0 instead of 42?

Observations:

jlebar commented 4 months ago

Thank you for the bug report, this looks real, I will see if I can have a look.

karthik-man commented 4 months ago

I am able to repro this and will work on fixing this.

karthik-man commented 4 months ago

My initial analysis: The issue seems to be with the alignment for the store that sets up 'value' for the vprintf.

The LLVM IR (llir) for this store has an alignment of 4:

%2 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$2 ld.global.b64 { $0 }, [ $1 + 0 ];", "=l,l,b"(ptr addrspace(1) %0, i1 true) #2, !dbg !12
...
store i64 %2, ptr %9, align 4

This results in the value being split across two PTX stores:

@%p1 ld.global.b64 { %rd1 }, [ %rd2 + 0 ];
...
st.local.u32 [%r4+12], %rd1;
shr.u64 %rd6, %rd1, 32;
st.local.u32 [%r4+16], %rd6;

I am still not clear how this split results in 0 being printed. With NVCC, the alignment is 8: https://godbolt.org/z/Y81rafYPo

llir/ptx files: 4060_i64.ptx.txt, 4060_i64.llir.txt
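
As a side note on the split store described above: on a little-endian target, writing a 64-bit value as two 32-bit halves (low half at the base offset, high half 4 bytes later) leaves the same bytes in memory as a single 64-bit store. The sketch below models this with Python's `struct` module. The buffer size and the helper names `store_u64` and `store_split_u32` are illustrative assumptions and are not taken from Triton or the PTX. It suggests the split by itself would not explain the 0, so the offset at which the value is read back is the more likely suspect.

```python
import struct

def store_u64(buf, offset, value):
    """Write value as a single little-endian 64-bit store."""
    struct.pack_into("<Q", buf, offset, value)

def store_split_u32(buf, offset, value):
    """Write the same value as two 32-bit stores: low half, then high half."""
    struct.pack_into("<I", buf, offset, value & 0xFFFFFFFF)  # like st.local.u32 [base+12]
    struct.pack_into("<I", buf, offset + 4, value >> 32)     # like shr 32, then st.local.u32 [base+16]

# Hypothetical 24-byte argument buffer; the value lives at byte offset 12.
single = bytearray(24)
split = bytearray(24)
store_u64(single, 12, 42)
store_split_u32(split, 12, 42)

# True when both buffers hold identical bytes.
same_bytes = single == split
```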

htyu commented 4 months ago

Thanks for the detailed analysis! I suspect the alignment of the struct variable passed into vprintf might be wrong.

In your PTX code, the print format (variable printfFormat_0) has the value of

"pid (%u, %u, %u) idx () value in kernel: %llu\n"

which is passed in as the first parameter of vprintf. The second parameter is the address of the struct object, whose last field holds the int64 value. That field has an alignment of 4, but I think vprintf expects 8, so the lower half of the value is treated as padding and only the upper half is printed.
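
The hypothesis above can be modeled in a few lines of Python. This is only a sketch under stated assumptions: the layout (three 32-bit program ids followed by one 64-bit value) is inferred from the format string, the reader is assumed to look for the 64-bit field at the 8-byte-aligned offset 16, and the helper names are hypothetical. Bytes past the end of the packed buffer are zero-filled here, which real device memory does not guarantee.

```python
import struct

def pack_args_align4(pids, value):
    """Three u32 fields, then the i64 at offset 12 (4-byte alignment, no padding)."""
    return struct.pack("<3Iq", *pids, value)

def pack_args_align8(pids, value):
    """Three u32 fields, 4 padding bytes, then the i64 at offset 16 (8-byte alignment)."""
    return struct.pack("<3I4xq", *pids, value)

def read_value_at_16(buf):
    """Model a reader that expects the 64-bit field at the 8-byte-aligned offset 16."""
    padded = bytes(buf).ljust(24, b"\x00")  # assumption: trailing memory is zero
    return struct.unpack_from("<q", padded, 16)[0]

pids = (0, 0, 0)
misread_small = read_value_at_16(pack_args_align4(pids, 42))
misread_large = read_value_at_16(pack_args_align4(pids, (7 << 32) | 42))
correct = read_value_at_16(pack_args_align8(pids, 42))
```

With the 4-byte-aligned layout, the reader sees only the upper 32 bits of the value, which are 0 for 42. That would be consistent with the output in the report.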

karthik-man commented 3 months ago

The kernel prints the right value if I run with DISABLE_LLVM_OPT=1.

I see that a GEP rewrite optimization (https://github.com/llvm/llvm-project/blob/90ba33099cbb17e7c159e9ebc5a512037db99d6d/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp#L2456) is converting the GEP to 'value', which uses index 3 into the arg struct passed to vprintf, into a GEP with a byte offset of 12 into that struct.

To check whether the DataLayout passed to the optimization is incorrect, I looked at the DataLayout used while translating to LLVM IR, and it seems to have the wrong alignment for int64: if I print llvmModule.getDataLayout().getABIIntegerTypeAlignment(64) inside mlir::translateModuleToLLVMIR(...), I see (llvm::Align) $2 = (ShiftValue = '\x02'). llvm::Align stores the log2 of the alignment, so a ShiftValue of 2 means the int64 alignment specified in the DataLayout is 1 << 2 = 4. I am tracking down the cause for this.
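
The arithmetic behind the two numbers in this comment can be sketched as follows. The code is a hedged illustration of the usual struct layout rule (round each field's offset up to its alignment), not LLVM's implementation. The field list for the vprintf argument struct is an assumption based on the format string, and `align_from_shift` and `field_offsets` are hypothetical helper names.

```python
def align_from_shift(shift):
    """llvm::Align stores log2(alignment); recover the alignment in bytes."""
    return 1 << shift

def field_offsets(fields):
    """Compute byte offsets for a list of (size, align) fields, in order."""
    offsets = []
    offset = 0
    for size, align in fields:
        # Round the running offset up to the next multiple of this field's alignment.
        offset = (offset + align - 1) // align * align
        offsets.append(offset)
        offset += size
    return offsets

i64_align_seen = align_from_shift(2)  # ShiftValue = '\x02'

# Assumed arg struct: three 32-bit pids followed by one 64-bit value.
offsets_align4 = field_offsets([(4, 4), (4, 4), (4, 4), (8, i64_align_seen)])
offsets_align8 = field_offsets([(4, 4), (4, 4), (4, 4), (8, 8)])
```

Under these assumptions, field index 3 sits at byte offset 12 when int64 is 4-byte aligned and at byte offset 16 when it is 8-byte aligned. That matches the rewritten GEP offset of 12 reported above.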