triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/

error: 'triton_gpu.cmpf' op requires the same encoding for all operands #1956

Closed: akakakakakaa closed this issue 1 year ago

akakakakakaa commented 1 year ago

I tried to run the Triton implementation of FlashAttention (not version 2), changing only tl.dot(q, k, trans_b=True) to tl.dot(q, tl.trans(k)).

But I received this error:

error: 'triton_gpu.cmpf' op requires the same encoding for all operands
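
For reference, the two spellings mentioned above are meant to be mathematically equivalent: the older tl.dot(q, k, trans_b=True) form contracts q against the rows of k, and tl.dot(q, tl.trans(k)) does the same after an explicit transpose. The sketch below illustrates this with plain Python lists rather than Triton tensors. The helper names matmul, transpose, and matmul_trans_b are hypothetical and exist only for this illustration.

```python
def transpose(m):
    """Return the transpose of a row-major matrix given as a list of lists."""
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """Plain a @ b for list-of-lists matrices."""
    b_t = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in b_t] for row in a]


def matmul_trans_b(a, b):
    """a @ b^T without materializing the transpose (the trans_b=True idea)."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b] for row_a in a]


# Toy stand-ins for a query block and a key block (assumed shapes: 2 x 2).
q = [[1, 2], [3, 4]]
k = [[5, 6], [7, 8]]

implicit = matmul_trans_b(q, k)        # analogous to tl.dot(q, k, trans_b=True)
explicit = matmul(q, transpose(k))     # analogous to tl.dot(q, tl.trans(k))
```

Because both forms produce the same values, the error most likely concerns how the compiler assigns layouts to intermediate tensors rather than the arithmetic itself.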

This is the full source code for the forward kernel (the only change from the FlashAttention source is tl.dot(q, k, trans_b=True) to tl.dot(q, tl.trans(k))):

import triton
import triton.language as tl


@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] % args["BLOCK_HEADDIM"] == 0,
    }
)
@triton.jit(debug=True)
def _fwd_kernel(
    Q,
    K,
    V,
    Bias,
    Out,
    Lse,
    TMP,  # NOTE: TMP is a scratchpad buffer to work around a compiler bug
    softmax_scale,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_kb,
    stride_kh,
    stride_kn,
    stride_vb,
    stride_vh,
    stride_vn,
    stride_bb,
    stride_bh,
    stride_bm,
    stride_ob,
    stride_oh,
    stride_om,
    nheads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    # Adding parentheses around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = (
        Q
        + off_b * stride_qb
        + off_h * stride_qh
        + (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K
        + off_b * stride_kb
        + off_h * stride_kh
        + (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V
        + off_b * stride_vb
        + off_h * stride_vh
        + (offs_n[:, None] * stride_vn + offs_d[None, :])
    )
    if BIAS_TYPE == "vector":
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == "matrix":
        b_ptrs = (
            Bias
            + off_b * stride_bb
            + off_h * stride_bh
            + (offs_m[:, None] * stride_bm + offs_n[None, :])
        )
    # initialize pointer to m and l
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(
                q_ptrs,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0,
            )
    # loop over k, v and update accumulator
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if (
            EVEN_N & EVEN_M
        ):  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0,
                )
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k)
                    & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        # Trying to combine the two masks seems to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(
                offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")
            )
        if BIAS_TYPE != "none":
            if BIAS_TYPE == "vector":
                if EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(
                        b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
                    ).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == "matrix":
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(
                        b_ptrs + start_n,
                        mask=(offs_m[:, None] < seqlen_q)
                        & ((start_n + offs_n)[None, :] < seqlen_k),
                        other=0.0,
                    ).to(tl.float32)
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need
            # to multiply with softmax_scale here.
            qk = qk * softmax_scale + bias
            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
            p = tl.exp(qk - m_ij[:, None])
        else:
            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
            p = tl.exp(qk * softmax_scale - m_ij[:, None])
        l_ij = tl.sum(p, 1)

        # scale acc_o
        acc_o_scale = tl.exp(m_i - m_ij)

        # # -- update output accumulator --
        # BUG: have to store and immediately load
        tl.store(t_ptrs, acc_o_scale)
        acc_o_scale = tl.load(t_ptrs)
        acc_o = acc_o * acc_o_scale[:, None]
        # update acc_o
        if (
            EVEN_N & EVEN_M
        ):  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0,
                )
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k)
                    & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)

        # -- update statistics
        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)

    o_scale = tl.exp(m_i - lse_i)
    # BUG: have to store and immediately load
    tl.store(t_ptrs, o_scale)
    o_scale = tl.load(t_ptrs)
    acc_o = acc_o * o_scale[:, None]
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    tl.store(lse_ptrs, lse_i)
    # initialize pointers to output
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = (
        Out
        + off_b * stride_ob
        + off_h * stride_oh
        + (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(
                out_ptrs,
                acc_o,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
            )

Here is the stack trace:

loc(callsite("/home/mansu/anaconda3/envs/ft-track/lib/python3.8/site-packages/triton/language/core.py":1398:21 at "/home/mansu/huggingface/pingpong-llm-003-7b-chat-style-instruction-masking-3epoch/flash_attn_triton.py":222:45)): error: 'triton_gpu.cmpf' op requires the same encoding for all operands
Traceback (most recent call last):
  File "/home/mansu/anaconda3/envs/ft-track/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/mansu/anaconda3/envs/ft-track/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/mansu/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/mansu/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/mansu/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 317, in run_module
    run_module_as_main(options.target, alter_argv=True)
  File "/home/mansu/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 238, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/mansu/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/home/mansu/huggingface/pingpong-llm-003-7b-chat-style-instruction-masking-3epoch/compile_test.py", line 412, in <module>
    test(model, input_ids=input_ids, attention_mask=attention_mask)
  File "/home/mansu/huggingface/pingpong-llm-003-7b-chat-style-instruction-masking-3epoch/compile_test.py", line 342, in test
    _ = model(**kwargs)
  File "/home/mansu/anaconda3/envs/ft-track/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mansu/huggingface/pingpong-llm-003-7b-chat-style-instruction-masking-3epoch/compile_test.py", line 217, in forward
    (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
  File "/home/mansu/anaconda3/envs/ft-track/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mansu/huggingface/pingpong-llm-003-7b-chat-style-instruction-masking-3epoch/blocks.py", line 36, in forward
    (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
  File "/home/mansu/anaconda3/envs/ft-track/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mansu/huggingface/pingpong-llm-003-7b-chat-style-instruction-masking-3epoch/attention.py", line 172, in forward
    (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
  File "/home/mansu/huggingface/pingpong-llm-003-7b-chat-style-instruction-masking-3epoch/attention.py", line 112, in triton_flash_attn_fn
    attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
  File "/home/mansu/anaconda3/envs/ft-track/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/mansu/huggingface/pingpong-llm-003-7b-chat-style-instruction-masking-3epoch/flash_attn_triton.py", line 1201, in forward
    o, lse, ctx.softmax_scale = _flash_attn_forward(
  File "/home/mansu/huggingface/pingpong-llm-003-7b-chat-style-instruction-masking-3epoch/flash_attn_triton.py", line 915, in _flash_attn_forward
    _fwd_kernel[grid](
  File "/home/mansu/anaconda3/envs/ft-track/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 232, in run
    return self.fn.run(*args, **kwargs)
  File "<string>", line 63, in _fwd_kernel
  File "/home/mansu/anaconda3/envs/ft-track/lib/python3.8/site-packages/triton/compiler/compiler.py", line 495, in compile
    next_module = compile_kernel(module)
  File "/home/mansu/anaconda3/envs/ft-track/lib/python3.8/site-packages/triton/compiler/compiler.py", line 402, in <lambda>
    lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, arch))
  File "/home/mansu/anaconda3/envs/ft-track/lib/python3.8/site-packages/triton/compiler/compiler.py", line 91, in optimize_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed

Can you advise me on how to debug Triton code? I can't find any problem at line 222 of the source code:

                    bias = tl.load(
                        b_ptrs + start_n,
                        mask=(offs_m[:, None] < seqlen_q)
                        & ((start_n + offs_n)[None, :] < seqlen_k),  # <= this part
                        other=0.0,
                    ).to(tl.float32)

Thanks.

cat538 commented 1 year ago

Same problem. Has anyone solved it?

CyanHillFox commented 1 year ago

I got the same problem. The line number reported by the Triton compiler may be unrelated to this error, since I don't even pass bias to this kernel. I looked into verifySameEncoding https://github.com/openai/triton/blob/6dee55c912ad8320ebc63e69a77bb83c19d1f19e/lib/Dialect/Triton/IR/Traits.cpp#L8 but couldn't figure out what the encoding of a RankedTensorType is.
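
As a rough mental model only: in MLIR a ranked tensor type carries a shape, an element type, and an optional encoding attribute, and in the TritonGPU dialect that attribute describes how the tensor's elements are laid out across threads. The toy sketch below is not Triton's implementation. TensorType, verify_same_encoding, and the encoding strings "blocked" and "mma" are hypothetical stand-ins chosen to show why an elementwise comparison would be rejected when its operands disagree.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class TensorType:
    """Toy stand-in for a ranked tensor type: shape, dtype, optional encoding."""
    shape: Tuple[int, ...]
    dtype: str
    encoding: Optional[str] = None  # hypothetical layout tag, e.g. "blocked"


def verify_same_encoding(op_name: str, operands: Sequence[TensorType]) -> None:
    """Raise if the operands do not all carry the same encoding tag."""
    encodings = {t.encoding for t in operands}
    if len(encodings) > 1:
        raise ValueError(
            f"'{op_name}' op requires the same encoding for all operands"
        )


# Two operands with matching layouts pass the check.
lhs = TensorType((128, 128), "f32", encoding="blocked")
rhs = TensorType((128, 128), "f32", encoding="blocked")
verify_same_encoding("cmpf", [lhs, rhs])

# A layout mismatch (assumed here to come from an earlier rewrite) is rejected.
mismatched = TensorType((128, 128), "f32", encoding="mma")
try:
    verify_same_encoding("cmpf", [lhs, mismatched])
    message = None
except ValueError as exc:
    message = str(exc)
```

Under this model the reported source line matters less than which compiler pass left the two operands with different layouts, which is why the IR from just before the failing pass is the useful artifact.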

ThomasRaoux commented 1 year ago

Is this still happening? I suspect this has been fixed. I'm not able to run the example you sent; could you please include the full example or the IR from before the pass that causes the problem?

Skylion007 commented 1 year ago

@ThomasRaoux Can you repro the issue from this PR: https://github.com/Dao-AILab/flash-attention/pull/458 ?

ThomasRaoux commented 1 year ago

What command line do you run? When I run python flash_attn_triton.py, nothing happens.

Skylion007 commented 1 year ago

@ThomasRaoux ~pytest -q -s tests/test_flash_attn.py will run it through the gpt-2 model tests, benchmark will run the external triton implementation if that is what you are asking. https://github.com/Dao-AILab/flash-attention/blob/866a9d33f9bcab0742d007c720ade1e1b79d1d79/benchmarks/benchmark_flash_attention.py#L81~

Let me ask around.

ThomasRaoux commented 1 year ago

The problem is most likely fixed by https://github.com/openai/triton/commit/d4644d6cb3ae674e1f15932cac1f28104795744f. Can you check whether it still reproduces with top of tree? If you are able to reproduce it, sharing the IR might be the simplest way for me or someone else to look at it.

jimmieliu commented 1 year ago

I tested with triton_nightly-2.1.0.dev20230822000928-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl, and the "triton_gpu.cmpf" error is fixed.

jinzhen-lin commented 1 year ago

I found that if I comment out these two lines, the code can run and the calculation result is correct.

        tl.store(t_ptrs, acc_o_scale)
        acc_o_scale = tl.load(t_ptrs)
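
To see why dropping the store/load round trip should not change the numbers, here is a minimal pure-Python sketch of the running-softmax update the kernel performs, assuming a single query row, scalar values, no bias, and softmax_scale already folded into the scores. The function names online_softmax_weighted_sum and reference are hypothetical. The variable names mirror the kernel's m_i, lse_i, and acc_o.

```python
import math


def online_softmax_weighted_sum(scores, values, block_n):
    """Blockwise softmax(scores) . values, mirroring the kernel's update order."""
    lse_i = float("-inf")   # running log-sum-exp
    m_i = float("-inf")     # reference point used for the previous block
    acc_o = 0.0             # unnormalized output accumulator
    for start_n in range(0, len(scores), block_n):
        qk = scores[start_n:start_n + block_n]
        v = values[start_n:start_n + block_n]
        m_ij = max(max(qk), lse_i)
        p = [math.exp(s - m_ij) for s in qk]
        l_ij = sum(p)
        # Rescale the accumulator; the kernel stores and reloads this value,
        # which is a numerical no-op, so it is used directly here.
        acc_o_scale = math.exp(m_i - m_ij)
        acc_o = acc_o * acc_o_scale + sum(pi * vi for pi, vi in zip(p, v))
        # Update statistics.
        m_i = m_ij
        lse_i = m_ij + math.log(math.exp(lse_i - m_ij) + l_ij)
    return acc_o * math.exp(m_i - lse_i), lse_i


def reference(scores, values):
    """Direct softmax-weighted sum and log-sum-exp, for comparison."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return sum(wi * vi for wi, vi in zip(w, values)) / total, m + math.log(total)


scores = [0.5, -1.0, 2.0, 0.1, 3.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
blockwise_out, blockwise_lse = online_softmax_weighted_sum(scores, values, block_n=2)
direct_out, direct_lse = reference(scores, values)
```

The blockwise and direct results agree up to floating-point rounding. This is consistent with the observation above: the kernel's own comments describe the store/load pair as a workaround for a compiler bug, not as part of the algorithm.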