triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Illegal memory access in triton kernel and non-determinism codegen #1076

Open: xw285cornell opened this issue 1 year ago

xw285cornell commented 1 year ago

Hi team, we seem to have identified an issue where Triton generates non-deterministic code, and in certain cases (rarely) the kernel hits an illegal memory access that looks like an out-of-bounds access to shared memory. The Triton code is fairly simple (shown below). We ran it many times (with PyTorch's codegen cache disabled), and it produced different shared-memory sizes:

Most of the time: {"name": "triton0d1d2d3d4", "shared": 2560, "num_warps": 4, "num_stages": 1}
Sometimes: {"name": "triton0d1d2d3d4", "shared": 544, "num_warps": 4, "num_stages": 1}

So sometimes the kernel requests only 544 bytes of shared memory, which might lead to out-of-bounds accesses to shared memory. We checked, and the TTIR is identical across the two runs, but the LLIR differs. Could you shed some light on this (and on whether codegen is deterministic)? Also, this is still on the legacy LLVM IR backend; we're not sure whether it reproduces with MLIR.
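The two metadata records above can be compared mechanically when looking for this across many runs. A minimal sketch, assuming each run's metadata has already been collected as a JSON string with the keys shown above (how the records are pulled out of the cache is not shown, and `inconsistent_shared` is a hypothetical helper name):

```python
import json
from collections import defaultdict


def inconsistent_shared(records):
    """Map kernel name -> sorted distinct 'shared' sizes, keeping only
    kernels that were seen with more than one size."""
    seen = defaultdict(set)
    for rec in records:
        meta = json.loads(rec)
        seen[meta["name"]].add(meta["shared"])
    return {name: sorted(sizes) for name, sizes in seen.items() if len(sizes) > 1}


# The two records quoted in the report.
runs = [
    '{"name": "triton0d1d2d3d4", "shared": 2560, "num_warps": 4, "num_stages": 1}',
    '{"name": "triton0d1d2d3d4", "shared": 544, "num_warps": 4, "num_stages": 1}',
]
print(inconsistent_shared(runs))  # {'triton0d1d2d3d4': [544, 2560]}
```

An empty result means every kernel name mapped to a single shared-memory size across the collected runs.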

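import triton
import triton.language as tl
from torch._inductor.triton_ops.autotune import pointwise, TileHint
from triton.compiler import instance_descriptor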

@pointwise(
    size_hints=[262144, 64],
    tile_hint=TileHint.DEFAULT,
    filename="notebook",
    meta={
        "signature": {0: "*bf16", 1: "*fp32", 2: "*bf16", 3: "i32", 4: "i32"},
        "device": 0,
        "constants": {},
        "mutated_arg_names": [],
        "configs": [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())],
    },
)
@triton.jit
def triton_(
    in_ptr0,
    in_ptr1,
    out_ptr0,
    xnumel,
    ynumel,
    XBLOCK: tl.constexpr,
    YBLOCK: tl.constexpr,
):
    xnumel = 262144
    ynumel = 62
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    x0 = xindex % 512
    x1 = xindex // 512
    y2 = yindex
    x3 = xindex
    tmp0 = tl.load(
        in_ptr0
        + (
            512
            + x0
            + (1536 * x1)
            + (786432 * y2)
            + (786432 * (((x0 + (512 * x1)) // 262144)))
        ),
        xmask & ymask,
    ).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (512 + x0), xmask)
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tmp4 = 2.8284271247461903
    tmp5 = tmp3 / tmp4
    tl.store(
        out_ptr0 + (y2 + (62 * x3) + tl.zeros([XBLOCK, YBLOCK], tl.int32)),
        tmp5,
        xmask & ymask,
    )
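The address arithmetic in this kernel is easier to audit once it is rewritten as tensor indexing. Read this way, the `in_ptr0` load is `in0.view(62, 512, 1536)[y, x // 512, 512 + x % 512]`: `x0 + 512 * x1` equals `xindex`, which is below 262144 for unmasked lanes, so the last term of the offset is always zero. The store writes `out.view(262144, 62)[x, y]`, i.e. a transpose. This is a reading of the generated code, not something stated in the thread; the pure-Python spot check below (with hypothetical helper names) tests it:

```python
# Constants copied from the kernel above.
XNUMEL, YNUMEL = 262144, 62


def in0_offset(x, y):
    """Flat element offset into in_ptr0, transcribed from the kernel's tl.load."""
    x0, x1 = x % 512, x // 512
    return 512 + x0 + 1536 * x1 + 786432 * y + 786432 * ((x0 + 512 * x1) // 262144)


def in0_offset_as_view(x, y):
    """Offset of in0.view(62, 512, 1536)[y, x // 512, 512 + x % 512]."""
    i, j, k = y, x // 512, 512 + x % 512
    return (i * 512 + j) * 1536 + k


def out_offset(x, y):
    """Flat offset into out_ptr0: element [x, y] of out viewed as (262144, 62)."""
    return y + 62 * x


# Spot-check the equivalence on a sparse grid of unmasked (x, y) pairs,
# including the last valid x.
for x in list(range(0, XNUMEL, 4099)) + [XNUMEL - 1]:
    for y in range(YNUMEL):
        assert in0_offset(x, y) == in0_offset_as_view(x, y)
```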

Jokeren commented 1 year ago

Can you provide us with the input and output tensors? Thanks!

xw285cornell commented 1 year ago

The input tensors are (the third one is the output buffer, out_ptr0):

args = [
    torch.randn((31744, 1536), dtype=torch.bfloat16, device=device),
    torch.randn((1536), dtype=torch.float32, device=device),
    torch.randn((512, 8, 64, 62), dtype=torch.bfloat16, device=device),
    262144,
    62,
]

with XBLOCK=32 and YBLOCK=32.

The illegal memory access turned out to be caused by a separate bug in PyTorch (a race in the cache), but the non-determinism still seems to be there.
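Since the crash was traced to the cache rather than to the kernel, it is also possible to confirm that the kernel's global-memory accesses stay in bounds for these inputs. A sketch, assuming masked-off lanes (y >= 62, which occur because the second grid dimension covers 64 columns with YBLOCK=32) never touch memory, and reusing the offset formulas from the kernel. It says nothing about shared memory, whose layout is chosen by the compiler:

```python
from math import prod

# Shapes from the args list above; the kernel indexes each tensor by flat element offset.
IN0_NUMEL = prod((31744, 1536))
IN1_NUMEL = prod((1536,))
OUT_NUMEL = prod((512, 8, 64, 62))
XNUMEL, YNUMEL = 262144, 62


def offsets(x, y):
    """Element offsets for in_ptr0, in_ptr1 and out_ptr0, as written in the kernel."""
    x0, x1 = x % 512, x // 512
    in0 = 512 + x0 + 1536 * x1 + 786432 * y + 786432 * ((x0 + 512 * x1) // 262144)
    in1 = 512 + x0
    out = y + 62 * x
    return in0, in1, out


# in0 and out grow monotonically with x and y; in1 depends only on x % 512,
# which is 511 at x = XNUMEL - 1. So the largest unmasked indices give the
# largest addresses for all three pointers.
max_in0, max_in1, max_out = offsets(XNUMEL - 1, YNUMEL - 1)
print(max_in0 < IN0_NUMEL, max_in1 < IN1_NUMEL, max_out < OUT_NUMEL)  # True True True
```

The output offset lands exactly on the last element of the (512, 8, 64, 62) buffer, which is consistent with the kernel writing that buffer in full.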

ptillet commented 1 year ago

I couldn't reproduce this with the new MLIR backend: shared is now always 2176 in my runs. I used the following script:

import torch
from torch._inductor.triton_ops.autotune import pointwise, TileHint
from triton.compiler import instance_descriptor
import triton
import triton.language as tl
import glob
import os
import shutil
import string

@pointwise(
    size_hints=[262144, 64],
    tile_hint=TileHint.DEFAULT,
    filename="notebook",
    meta={
        "signature": {0: "*bf16", 1: "*fp32", 2: "*bf16", 3: "i32", 4: "i32"},
        "device": 0,
        "constants": {},
        "mutated_arg_names": [],
        "configs": [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())],
    },
)
@triton.jit
def triton_(
    in_ptr0,
    in_ptr1,
    out_ptr0,
    xnumel,
    ynumel,
    XBLOCK: tl.constexpr,
    YBLOCK: tl.constexpr,
):
    xnumel = 262144
    ynumel = 62
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    x0 = xindex % 512
    x1 = xindex // 512
    y2 = yindex
    x3 = xindex
    tmp0 = tl.load(
        in_ptr0
        + (
            512
            + x0
            + (1536 * x1)
            + (786432 * y2)
            + (786432 * (((x0 + (512 * x1)) // 262144)))
        ),
        xmask & ymask,
    ).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (512 + x0), xmask)
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tmp4 = 2.8284271247461903
    tmp5 = tmp3 / tmp4
    tl.store(
        out_ptr0 + (y2 + (62 * x3) + tl.zeros([XBLOCK, YBLOCK], tl.int32)),
        tmp5,
        xmask & ymask,
    )

device = "cuda"
args = [
    torch.randn((31744, 1536), dtype=torch.bfloat16, device=device),
    torch.randn((1536), dtype=torch.float32, device=device),
    torch.randn((512, 8, 64, 62), dtype=torch.bfloat16, device=device),
    262144,
    62,
]
XBLOCK=32
YBLOCK=32

stream = torch.cuda.current_stream()
pgm = triton_.fn[(1,1,1)](args[0], args[1], args[2], args[3], args[4], XBLOCK, YBLOCK)
print(pgm.shared)

and ran it 100 times, deleting the cache in between runs.
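The repeated-compile experiment can be scripted. A sketch, assuming the script above is saved as a file (hypothetically `repro.py`) that prints `pgm.shared` as its last line, and that the installed Triton honours the `TRITON_CACHE_DIR` environment variable; pointing each run at a fresh temporary directory stands in for deleting the cache by hand:

```python
import os
import subprocess
import sys
import tempfile
from collections import Counter


def count_outputs(cmd, runs=100):
    """Run `cmd` `runs` times, each time with a fresh, empty Triton cache
    directory, and tally the last line each run prints."""
    tally = Counter()
    for _ in range(runs):
        with tempfile.TemporaryDirectory() as cache_dir:
            env = dict(os.environ, TRITON_CACHE_DIR=cache_dir)
            result = subprocess.run(cmd, env=env, capture_output=True, text=True, check=True)
            tally[result.stdout.strip().splitlines()[-1]] += 1
    return tally


# Hypothetical usage:
# print(count_outputs([sys.executable, "repro.py"]))
```

A single key in the returned Counter means the shared-memory size was stable across runs; more than one key would reproduce the original report.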

xw285cornell commented 1 year ago

Thanks @ptillet! Yes, it may be a problem specific to the LLVM backend that predates MLIR. Just curious: is codegen guaranteed to be deterministic in the MLIR version?

I think we found the root cause: both variants of the generated code (using 544 or 2560 bytes of shared memory) are valid, but there is a race when writing the code cache. Triton writes the IR and the metadata (including the shared-memory size) in separate transactions, each of which is atomic on its own. So if multiple processes compile the same kernel concurrently, we can end up with code from one compilation paired with metadata from another. Would you consider making commits to the code cache atomic?
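One common way to make a cache entry's code and metadata become visible together is to stage both files privately and publish them with a single rename. This is a sketch under stated assumptions, not Triton's cache implementation: the one-directory-per-key layout, the file names `kernel.bin` and `kernel.json`, and the functions `commit_entry` and `load_entry` are all hypothetical, and the atomicity relies on POSIX `rename` semantics within a single filesystem:

```python
import json
import os
import shutil
import tempfile


def commit_entry(cache_dir, key, binary, metadata):
    """Publish kernel bytes and metadata together, or not at all.

    Both files are written into a private staging directory, then renamed
    into place with one os.rename call. Renaming onto an existing non-empty
    directory fails, so the first writer wins and later writers discard
    their copy instead of overwriting half of an entry.
    """
    os.makedirs(cache_dir, exist_ok=True)
    final_dir = os.path.join(cache_dir, key)
    staging = tempfile.mkdtemp(prefix=f".{key}-", dir=cache_dir)
    try:
        with open(os.path.join(staging, "kernel.bin"), "wb") as f:
            f.write(binary)
        with open(os.path.join(staging, "kernel.json"), "w") as f:
            json.dump(metadata, f)
        try:
            os.rename(staging, final_dir)
        except OSError:
            if not os.path.isdir(final_dir):
                raise  # a genuine failure, not a lost race
    finally:
        # No-op after a successful rename; cleans up after a lost race.
        shutil.rmtree(staging, ignore_errors=True)
    return final_dir


def load_entry(final_dir):
    """Read back a (binary, metadata) pair written by commit_entry."""
    with open(os.path.join(final_dir, "kernel.bin"), "rb") as f:
        binary = f.read()
    with open(os.path.join(final_dir, "kernel.json")) as f:
        metadata = json.load(f)
    return binary, metadata
```

If two processes commit the same key, readers always see one writer's code together with that same writer's metadata. Bundling both payloads into a single file published with `os.replace` would give the same guarantee; the directory rename simply keeps the two files separate on disk.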