triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Incorrect result from first-run autotune function #781

Open jgong5 opened 1 year ago

jgong5 commented 1 year ago

Repro:

import torch
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'XBLOCK': 1, 'RBLOCK': 64}),
        triton.Config({'XBLOCK': 1, 'RBLOCK': 64}),
    ],
    key=['xnumel', 'rnumel',],
)
@triton.jit
def kernel(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr4, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 392
    rnumel = 384
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    x0 = xindex % 196
    x1 = (xindex // 196)
    _tmp3 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    x3 = xindex
    _tmp6 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (196*r2) + (75264*x1)), xmask & rmask, eviction_policy='evict_last')
        tmp1 = tl.load(in_ptr1 + (r2), rmask, eviction_policy='evict_last')
        tmp4 = tl.load(in_ptr2 + (x0 + (196*r2) + (75264*x1)), xmask & rmask, eviction_policy='evict_last')
        tmp2 = tmp0 * tmp1
        _tmp3 = tl.where(xmask & rmask, _tmp3 + tmp2, _tmp3)
        tmp5 = tmp2 * tmp4
        _tmp6 = tl.where(xmask & rmask, _tmp6 + tmp5, _tmp6)
    tmp3 = tl.reshape(tl.sum(_tmp3, 1), [XBLOCK, 1])
    tmp6 = tl.reshape(tl.sum(_tmp6, 1), [XBLOCK, 1])
    tmp8 = tl.load(in_ptr4 + (x3), xmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp7 = tl.load(in_out_ptr0 + (r2 + (384*x3)), xmask & rmask, eviction_policy='evict_last')
        tmp9 = tl.load(in_ptr0 + (x0 + (196*r2) + (75264*x1)), xmask & rmask, eviction_policy='evict_last')
        tmp10 = tl.load(in_ptr1 + (r2), rmask, eviction_policy='evict_last')
        tmp15 = tl.load(in_ptr2 + (x0 + (196*r2) + (75264*x1)), xmask & rmask, eviction_policy='evict_last')
        tmp11 = tmp9 * tmp10
        tmp12 = 384
        tmp13 = tmp11 * tmp12
        tmp14 = tmp13 - tmp3
        tmp16 = tmp15 * tmp6
        tmp17 = tmp14 - tmp16
        tmp18 = tmp8 * tmp17
        tmp19 = tmp7 + tmp18
        tl.store(in_out_ptr0 + (r2 + (384*x3) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp19, xmask & rmask)

if __name__ == "__main__":
    def rand_strided(size, stride, dtype=torch.float32, device="cpu"):
        needed_size = sum((shape - 1) * stride for shape, stride in zip(size, stride)) + 1
        if dtype.is_floating_point:
            buffer = torch.randn(needed_size, dtype=dtype, device=device)
        else:
            buffer = torch.ones(size=[needed_size], dtype=dtype, device=device)
        return torch.as_strided(buffer, size, stride)
    torch.manual_seed(1337)

    def grid(xnumel):
        """Helper function to compute triton grids"""
        from triton import cdiv
        def grid_fn(meta):
            return (
                cdiv(xnumel, meta["XBLOCK"]),
                1,
                1,
            )
        return grid_fn

    div_3 = rand_strided((2, 196, 1), (196, 1, 392), device='cuda', dtype=torch.float32)
    mul = rand_strided((2, 196, 384), (75264, 1, 196), device='cuda', dtype=torch.float32)
    primals_3 = rand_strided((384, ), (1, ), device='cuda', dtype=torch.float32)
    buf34 = rand_strided((2, 384, 196), (75264, 196, 1), device='cuda', dtype=torch.float32)
    buf39_1 = rand_strided((2, 196, 384), (75264, 384, 1), device='cuda', dtype=torch.float32)
    buf39_2 = buf39_1.clone()
    buf39_3 = buf39_1.clone()

    stream0 = get_cuda_stream(0)
    kernel.run(buf39_1, buf34, primals_3, mul, div_3, 392, 384, grid=grid(392), stream=stream0)
    kernel.run(buf39_2, buf34, primals_3, mul, div_3, 392, 384, grid=grid(392), stream=stream0)
    kernel.run(buf39_3, buf34, primals_3, mul, div_3, 392, 384, grid=grid(392), stream=stream0)

    assert torch.allclose(buf39_2, buf39_3, atol=0.001, rtol=0.001, equal_nan=True)
    assert torch.allclose(buf39_1, buf39_2, atol=0.001, rtol=0.001, equal_nan=True)

Expected result: neither assertion fails. Actual result: the second assertion fails. If only one triton.Config is kept (i.e., a single triton.Config({'XBLOCK': 1, 'RBLOCK': 64})), the test passes. See https://github.com/pytorch/torchdynamo/issues/1670 for more background.

jgong5 commented 1 year ago

Reproduced with Triton commit: af76c989eb4799b015f8b288ccd8421558772e56

jgong5 commented 1 year ago

cc @jansel

ptillet commented 1 year ago

I don't think it's technically a bug. The autotuner simply runs every config one by one, on the same arguments, to see which one is fastest. In this case, you have:

        tmp7 = tl.load(in_out_ptr0 + (r2 + (384*x3)), xmask & rmask, eviction_policy='evict_last')
        ...
        tl.store(in_out_ptr0 + (r2 + (384*x3) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp19, xmask & rmask)

so when the kernel runs more than once, the first call modifies in_out_ptr0, and the second call loads the already-modified values instead of the original input.
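To make the failure mode concrete without a GPU, here is a toy model in plain Python. It is not Triton's autotuner: `inplace_kernel` and `naive_autotune_and_run` are hypothetical names, lists stand in for device buffers, and timing and config selection are left out. It assumes, as the autotuner of this era appears to do, that benchmarking is skipped when only one config is given, which would explain why the repro passes with a single `triton.Config`.

```python
def inplace_kernel(out, inp, block):
    # Toy stand-in for the repro's kernel: a read-modify-write of `out`
    # (like the tmp7 load and tmp19 store on in_out_ptr0). `block` only
    # changes the iteration order, never the result.
    n = len(out)
    for start in range(0, n, block):
        for i in range(start, min(start + block, n)):
            out[i] = out[i] + inp[i]


def naive_autotune_and_run(kernel, configs, *args):
    # Hypothetical, simplified autotuner. With more than one config it runs
    # every config on the caller's real arguments (timing omitted), then runs
    # the chosen config once more to produce the actual result.
    if len(configs) > 1:
        for cfg in configs:
            kernel(*args, **cfg)
    kernel(*args, **configs[0])  # "fastest" config; selection omitted


inp = [1.0, 2.0, 3.0, 4.0]
one_cfg = [0.0] * 4
two_cfg = [0.0] * 4
naive_autotune_and_run(inplace_kernel, [{"block": 2}], one_cfg, inp)
naive_autotune_and_run(inplace_kernel, [{"block": 2}, {"block": 2}], two_cfg, inp)
print(one_cfg)  # [1.0, 2.0, 3.0, 4.0]: the kernel ran once
print(two_cfg)  # [3.0, 6.0, 9.0, 12.0]: two benchmark runs + the real run
```

As in the repro, only the first call goes through tuning; later calls with the same key reuse the cached config and run once, which is why `buf39_2` and `buf39_3` agree while `buf39_1` does not.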

The same issue arose when doing atomic_add in a matmul with SPLIT_K>1, and we fixed it by adding a hook to the autotuner that resets the relevant arguments to zero between runs: https://github.com/openai/triton-mlir/blob/main/python/triton/runtime/autotuner.py#L25-L33
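A toy version of the reset-to-zero hook, again in plain Python, is sketched below. `splitk_accumulate` and `autotune_with_reset` are hypothetical names; the only idea taken from Triton is naming the arguments to zero, which `triton.autotune` accepts through its `reset_to_zero` parameter.

```python
def splitk_accumulate(out, a, b, split_k):
    # Toy stand-in for a SPLIT_K > 1 matmul: each of the split_k "programs"
    # adds its partial dot product into out[0], as tl.atomic_add would, so
    # the kernel assumes out[0] starts at zero.
    n = len(a)
    chunk = (n + split_k - 1) // split_k
    for k in range(split_k):
        lo, hi = k * chunk, min((k + 1) * chunk, n)
        out[0] += sum(a[i] * b[i] for i in range(lo, hi))


def autotune_with_reset(kernel, configs, reset_to_zero, args):
    # Hypothetical toy autotuner; `args` maps argument names to values.
    def reset():
        # Zero the named buffers in place so the caller's references see it.
        for name in reset_to_zero:
            buf = args[name]
            buf[:] = [0.0] * len(buf)

    if len(configs) > 1:
        for cfg in configs:
            reset()               # zero before every benchmark run ...
            kernel(**args, **cfg)
        reset()                   # ... and once more before the real run
    kernel(**args, **configs[0])  # config selection omitted


a = [1.0, 2.0, 3.0, 4.0]
b = [1.0, 1.0, 1.0, 1.0]
out = [0.0]
autotune_with_reset(splitk_accumulate, [{"split_k": 2}, {"split_k": 4}],
                    ["out"], {"out": out, "a": a, "b": b})
print(out)  # [10.0]
```

Without the resets, `out` would end up as `[30.0]`. Zeroing is correct here only because this kernel expects a zeroed output; the repro's kernel reads the existing contents of `in_out_ptr0`, so zeroing it would still give a wrong answer.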

I imagine in your case we'd want the hook to initialize data from a copy.
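A sketch of what initializing from a copy could look like, under the same toy assumptions (plain Python lists as buffers, hypothetical function names, timing omitted): the named in-place arguments are snapshotted once, and the snapshot is written back before every benchmark run and before the real run.

```python
def inplace_kernel(out, inp, block):
    # Toy read-modify-write kernel standing in for the repro's kernel.
    n = len(out)
    for start in range(0, n, block):
        for i in range(start, min(start + block, n)):
            out[i] = out[i] + inp[i]


def autotune_with_restore(kernel, configs, restore, args):
    # Hypothetical toy autotuner; `args` maps argument names to values and
    # `restore` names the buffers the kernel modifies in place.
    snapshot = {name: list(args[name]) for name in restore}  # saved once

    def restore_all():
        # Copy the saved data back into the caller's buffers in place.
        for name in restore:
            args[name][:] = snapshot[name]

    if len(configs) > 1:
        for cfg in configs:
            restore_all()
            kernel(**args, **cfg)
        restore_all()
    kernel(**args, **configs[0])  # config selection omitted


inp = [1.0, 2.0, 3.0, 4.0]
out = [10.0, 20.0, 30.0, 40.0]
autotune_with_restore(inplace_kernel, [{"block": 1}, {"block": 2}],
                      ["out"], {"out": out, "inp": inp})
print(out)  # [11.0, 22.0, 33.0, 44.0]
```

The cost is one extra copy of each restored buffer for the duration of tuning, and the caller still has to name which arguments are modified.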

Re-tagging this as an enhancement, as we probably want to revisit the autotuner API and document this behavior properly. I don't think it'll be possible to write a general analysis pass that determines which arguments should be copied upfront: this would work in simple cases but would likely fail when aliasing behavior is dynamic.
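The dynamic-aliasing problem can be shown with a toy kernel whose safety depends on the call site rather than on the kernel itself. `scale_kernel` and `run_twice` are hypothetical names; `run_twice` stands in for one benchmark run followed by the real run.

```python
def scale_kernel(out, inp, factor):
    # Writes every element of `out` without reading it and only reads `inp`,
    # so a per-argument analysis would call `out` write-only and `inp`
    # read-only, and conclude that rerunning the kernel is harmless.
    for i in range(len(out)):
        out[i] = inp[i] * factor


def run_twice(out, inp):
    # Simulates one benchmark run followed by the real run.
    scale_kernel(out, inp, 2.0)
    scale_kernel(out, inp, 2.0)


x, y = [1.0, 2.0], [0.0, 0.0]
run_twice(y, x)   # distinct buffers: rerunning is harmless
z = [1.0, 2.0]
run_twice(z, z)   # same buffer passed twice: the first run clobbers `inp`
print(y)  # [2.0, 4.0]
print(z)  # [4.0, 8.0], not the [2.0, 4.0] a single call would give
```

Both calls use the same kernel; only the relationship between `out` and `inp` at run time differs, which a pass over the kernel alone cannot see.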

jansel commented 1 year ago

Ah I think you are correct. The autotuner runs the kernel multiple times, which is not safe to do for in-place kernels because the first run clobbers the input data.

Note that TorchInductor has its own subclass of the autotuner, so for this case we would need to modify https://github.com/pytorch/torchdynamo/blob/6380959be21851bfda99424392cc08fda29d073d/torchinductor/triton_ops/autotune.py#L145 to create copies of all the input tensors using: https://github.com/pytorch/torchdynamo/blob/6380959be21851bfda99424392cc08fda29d073d/torchinductor/compile_fx.py#L145
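Copying every input before benchmarking can be sketched as a small subclass. The classes below are hypothetical toys, not TorchInductor's subclass or the linked helper, which are not reproduced here; plain Python lists stand in for tensors.

```python
import copy


def inplace_kernel(out, inp, block):
    # Toy read-modify-write kernel standing in for the repro's kernel.
    n = len(out)
    for start in range(0, n, block):
        for i in range(start, min(start + block, n)):
            out[i] = out[i] + inp[i]


class ToyAutotuner:
    # Hypothetical base autotuner: benchmarks every config on the real args.
    def __init__(self, kernel, configs):
        self.kernel = kernel
        self.configs = configs

    def bench(self, args, cfg):
        self.kernel(*args, **cfg)  # timing omitted

    def run(self, *args):
        if len(self.configs) > 1:
            for cfg in self.configs:
                self.bench(args, cfg)
        self.kernel(*args, **self.configs[0])  # config selection omitted


class CloningAutotuner(ToyAutotuner):
    # Benchmarks on copies of all arguments, so only the final run touches
    # the caller's buffers. deepcopy copies each distinct object once, so
    # arguments that alias each other still alias in the copies.
    def bench(self, args, cfg):
        super().bench(copy.deepcopy(args), cfg)


configs = [{"block": 1}, {"block": 2}]
inp = [1.0, 2.0]
out_naive = [10.0, 20.0]
out_clone = [10.0, 20.0]
ToyAutotuner(inplace_kernel, configs).run(out_naive, inp)
CloningAutotuner(inplace_kernel, configs).run(out_clone, inp)
print(out_naive)  # [13.0, 26.0]
print(out_clone)  # [11.0, 22.0]
```

Copying everything spends extra memory and copy time during tuning, but it needs no knowledge of which arguments are written, so it sidesteps the analysis problem ptillet raised.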

jgong5 commented 1 year ago

@ptillet Thanks for the insights. I'd appreciate it if you could document this limitation and the recommended mitigation before a complete solution is available. Memory-related issues are hard to troubleshoot; if Triton could report warnings behind a debugging flag, that would make users' lives easier.
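A debug-only mutation check of the kind requested here could look like the following sketch. The environment variable `TOY_AUTOTUNE_CHECK_MUTATION` and the function `bench_once_with_check` are made up for illustration; this is not an existing Triton flag or API.

```python
import os
import warnings


def inplace_kernel(out, inp, block):
    # Toy read-modify-write kernel standing in for the repro's kernel.
    n = len(out)
    for start in range(0, n, block):
        for i in range(start, min(start + block, n)):
            out[i] = out[i] + inp[i]


def bench_once_with_check(kernel, args, cfg, names):
    # Hypothetical debug check, enabled by the made-up environment variable
    # TOY_AUTOTUNE_CHECK_MUTATION=1: run one benchmark call and warn about
    # every list argument whose contents it changed.
    if os.environ.get("TOY_AUTOTUNE_CHECK_MUTATION") != "1":
        kernel(*args, **cfg)
        return []
    before = [list(a) if isinstance(a, list) else None for a in args]
    kernel(*args, **cfg)
    mutated = [name for name, a, saved in zip(names, args, before)
               if saved is not None and a != saved]
    for name in mutated:
        warnings.warn(f"benchmark run modified argument {name!r}; "
                      "rerunning this kernel during autotuning may "
                      "produce wrong results")
    return mutated


os.environ["TOY_AUTOTUNE_CHECK_MUTATION"] = "1"
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    mutated = bench_once_with_check(
        inplace_kernel, [[10.0, 20.0], [1.0, 2.0]], {"block": 1},
        ["in_out_ptr0", "in_ptr0"])
print(mutated)  # ['in_out_ptr0']
```

Comparing contents before and after a run doubles the memory for the checked arguments, which is one reason such a check would sensibly sit behind an opt-in flag.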