triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Assertion failure in Linear Layouts when num_warps = 8, but passes with num_warps = 4 #5265

Closed: Moerafaat closed this issue 16 hours ago

Moerafaat commented 5 days ago

Describe the bug

To reproduce the issue, you can run the following python test:

import torch
import triton
import triton.language as tl

@triton.jit
def repro_kernel(q_ref,
                 k_ref,
                 v_ref,
                 output_ptr,
                 ):
    offsets64 = tl.arange(0, 64)
    offsets128 = tl.arange(0, 128)
    q = tl.load(q_ref + (offsets64[:, None] * 128 + offsets128[None, :]))
    k = tl.load(k_ref + (offsets128[:, None] * 64 + offsets64[None, :]))
    qk = tl.dot(q, k).to(tl.bfloat16)
    v = tl.load(v_ref + (offsets64[:, None] * 128 + offsets128[None, :]))
    o = tl.dot(qk, v)
    tl.store(output_ptr + (offsets64[:, None] * 128 + offsets128[None, :]), o.to(tl.bfloat16))

def repro(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    output = torch.empty((64, 128), dtype=torch.bfloat16, device='cuda')
    grid = lambda meta: (1, 1)
    kernel = repro_kernel[grid](q, k, v, output, num_warps=8, num_ctas=1, num_stages=3)
    # print(kernel.asm['ttir'])
    return output

torch.manual_seed(0)
q = torch.ones((64, 128), dtype=torch.bfloat16, device='cuda')
k = torch.ones((128, 64), dtype=torch.bfloat16, device='cuda')
v = torch.ones((64, 128), dtype=torch.bfloat16, device='cuda')
output_torch = (q @ k) @ v
output_triton = repro(q, k, v)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')

You will encounter the following error:

python3: /tmp/triton/lib/Tools/LinearLayout.cpp:526: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeOuts(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalOutDimSize() == std::accumulate( newOutDims.begin(), newOutDims.end(), 1, [&](int32_t acc, auto &outDim) { return acc * outDim.second; })' failed.
Aborted
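
For context, the assertion enforces a standard reshape invariant: the product of the new output-dimension sizes must equal the layout's total output size. A minimal Python sketch of the check (a hypothetical analogy; the real check is the C++ code in LinearLayout::reshapeOuts):

import math

def reshape_outs_check(total_out_dim_size, new_out_dims):
    # new_out_dims mirrors the (name, size) pairs passed to reshapeOuts
    product = math.prod(size for _name, size in new_out_dims)
    assert total_out_dim_size == product, (
        f"total out dim size {total_out_dim_size} != product {product}")

# Hypothetical numbers for illustration: this call satisfies the invariant...
reshape_outs_check(8192, [("offset", 8192), ("iteration", 1)])
# ...while any mismatched product trips the assertion, which is what the
# num_warps=8 lowering hits here.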

I noticed a similar report in https://github.com/triton-lang/triton/issues/4727 before that issue was re-opened. Interestingly, the failure actually started with the commit linked from that issue. The culprit commit is https://github.com/triton-lang/triton/commit/49266aa908d29a7029348ff480a75a3ea4d6e704

The test passes if num_warps is set to 4 instead of 8, and it worked properly before the culprit commit.
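
For anyone hitting this before a fix lands, the only change needed in the repro above is the launch line; everything else stays the same:

# Workaround: num_warps=4 avoids the assertion
kernel = repro_kernel[grid](q, k, v, output, num_warps=4, num_ctas=1, num_stages=3)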

Environment details

The issue reproduces on H100 with the latest Triton main: commit 8b29bb752033c8db578bf88624a6dc91dc8f45cf

Jokeren commented 5 days ago

Interesting. Taking a look now.

#4727 is about TMA, so it's not related.

Jokeren commented 5 days ago

FYI, I have a solution that works for it now using stmatrix. Will upstream soon.

With the fix, the layout's out dims are [offset (size 4096), iteration (size 1)], and the torch and Triton outputs now match:
tensor([[8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        ...,
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.]], device='cuda:0',
       dtype=torch.bfloat16)
tensor([[8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        ...,
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.],
        [8192., 8192., 8192.,  ..., 8192., 8192., 8192.]], device='cuda:0',
       dtype=torch.bfloat16)
Moerafaat commented 5 days ago

Thanks! Really appreciate the fast reply on this and looking forward to your fix 🙏

Jokeren commented 5 days ago

https://github.com/triton-lang/triton/pull/5277 is a partial fix. More general fixes will be pushed next week

Moerafaat commented 4 days ago

> https://github.com/triton-lang/triton/pull/5277 is a partial fix.

Tested it and it works great! Thanks for the fast turn-around!

Moerafaat commented 16 hours ago

Marking this fixed. Thanks for the assistance!