Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

benchmarks/benchmark_causal.py RuntimeError #274

Open mit10000 opened 1 year ago

mit10000 commented 1 year ago

Environment (NVIDIA PyTorch Docker container, release 22.04):
- Python 3.8.13
- torch 1.12.0a0+bd13bc6
- flash-attn 1.0.7
- triton 2.0.0
- GPU: NVIDIA GeForce RTX 2080 Ti
- CUDA 11.4

Command: `python benchmarks/benchmark_causal.py`

```
FlashAttention - Forward pass
Traceback (most recent call last):
  File "benchmarks/benchmark_causal.py", line 91, in <module>
    benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b s) ...'),
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/utils/benchmark.py", line 89, in benchmark_all
    benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose,
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/utils/benchmark.py", line 17, in benchmark_forward
    fn_amp(*inputs, **kwinputs)
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/utils/benchmark.py", line 15, in fn_amp
    fn(*inputs, **kwinputs)
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 256, in flash_attn_unpadded_qkvpacked_func
    return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 58, in forward
    out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 21, in _flash_attn_forward
    softmax_lse, rng_state, *rest = flash_attn_cuda.fwd(
RuntimeError: Expected q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
```

Could you please help?

tridao commented 1 year ago

RTX 2080 hardware doesn't support bfloat16. You can try changing the dtype to torch.float16 (i.e. fp16).
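The error message spells out the check: fp16 is always accepted, while bf16 additionally requires an sm8x or sm90 GPU, and the RTX 2080 Ti is compute capability 7.5 (Turing). A minimal sketch of choosing the dtype up front by mirroring that condition; the helper names `bf16_supported` and `pick_dtype_name` are hypothetical and not part of flash-attn:

```python
from typing import Tuple


def bf16_supported(capability: Tuple[int, int]) -> bool:
    """Mirror the condition in the RuntimeError: bf16 only if (is_sm8x || is_sm90)."""
    major, minor = capability
    is_sm8x = major == 8
    is_sm90 = (major, minor) == (9, 0)
    return is_sm8x or is_sm90


def pick_dtype_name(capability: Tuple[int, int]) -> str:
    """Name of a torch dtype that passes the check: bf16 where allowed, else fp16."""
    return "bfloat16" if bf16_supported(capability) else "float16"


if __name__ == "__main__":
    # (7, 5) is the compute capability of the RTX 2080 Ti (Turing).
    print(pick_dtype_name((7, 5)))  # float16
    # (8, 0) is the A100 (Ampere).
    print(pick_dtype_name((8, 0)))  # bfloat16
```

With PyTorch available, the capability tuple comes from `torch.cuda.get_device_capability()`, and `getattr(torch, pick_dtype_name(...))` turns the name into the dtype to pass to the benchmark.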

mit10000 commented 1 year ago

Thank you so much for your help, @tridao. I changed it to `dtype = torch.float16`, and now the output is:

```
FlashAttention - Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5d2ee00280>
fn_amp(*inputs, **kwinputs)
  1.29 ms
  1 measurement, 30 runs , 8 threads
FlashAttention - Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5ca34443a0>
y.backward(grad, retain_graph=True)
  1.36 ms
  1 measurement, 30 runs , 8 threads
FlashAttention - Forward + Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5ca4948790>
f(grad, *inputs, **kwinputs)
  1.84 ms
  1 measurement, 30 runs , 8 threads
PyTorch Attention - Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5ca345dfa0>
fn_amp(*inputs, **kwinputs)
  7.04 ms
  1 measurement, 30 runs , 8 threads
PyTorch Attention - Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5ca3438a60>
y.backward(grad, retain_graph=True)
  8.31 ms
  1 measurement, 30 runs , 8 threads
PyTorch Attention - Forward + Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5ca34443a0>
f(grad, *inputs, **kwinputs)
  15.33 ms
  1 measurement, 30 runs , 8 threads
FlashAttention Triton - Forward pass
Traceback (most recent call last):
  File "<string>", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0-1e8410f206c822547fb50e2ea86e45a6-d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-14de7de5c4da5794c8ca14e7e41a122d-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float16, torch.float16, torch.float16, None, torch.float16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('none', True, 32, True, True, True, 128, 128), (True, True, True, (False,), True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False)))
```

During handling of the above exception, another exception occurred:

```
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 937, in build_triton_ir
    generator.visit(fn.parse())
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 183, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/opt/conda/lib/python3.8/ast.py", line 379, in generic_visit
    self.visit(item)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 252, in visit_FunctionDef
    has_ret = self.visit_compound_statement(node.body)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 177, in visit_compound_statement
    self.last_ret_type = self.visit(stmt)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 678, in visit_For
    self.visit_compound_statement(node.body)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 177, in visit_compound_statement
    self.last_ret_type = self.visit(stmt)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 319, in visit_AugAssign
    self.visit(assign)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 301, in visit_Assign
    values = self.visit(node.value)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 339, in visit_BinOp
    rhs = self.visit(node.right)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 797, in visit_Call
    return fn(*args, _builder=self.builder, **kws)
  File "/opt/conda/lib/python3.8/site-packages/triton/impl/base.py", line 22, in wrapper
    return fn(*args, **kwargs)
TypeError: dot() got an unexpected keyword argument 'trans_b'
```

The above exception was the direct cause of the following exception:

```
Traceback (most recent call last):
  File "benchmarks/benchmark_causal.py", line 96, in <module>
    benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton')
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/utils/benchmark.py", line 89, in benchmark_all
    benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose,
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/utils/benchmark.py", line 17, in benchmark_forward
    fn_amp(*inputs, **kwinputs)
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/utils/benchmark.py", line 15, in fn_amp
    fn(*inputs, **kwinputs)
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/flash_attn_triton.py", line 733, in forward
    o, lse, ctx.softmax_scale = _flash_attn_forward(
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/flash_attn_triton.py", line 623, in _flash_attn_forward
    _fwd_kernel[grid](
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 199, in run
    return self.fn.run(*args, **kwargs)
  File "<string>", line 41, in _fwd_kernel
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 1621, in compile
    next_module = compile(module)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 1550, in <lambda>
    lambda src: ast_to_ttir(src, signature, configs[0], constants)),
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 962, in ast_to_ttir
    mod, _ = build_triton_ir(fn, signature, specialization, constants)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 942, in build_triton_ir
    raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 78:24:
def _fwd_kernel(
    Q, K, V, Bias, Out,
    Lse, TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_bb, stride_bh, stride_bm,
    stride_ob, stride_oh, stride_om,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # off_b = tl.program_id(1)
    # off_h = tl.program_id(2)
    # off_hb = off_b * nheads + off_h
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    # Adding parenthesis around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == 'matrix':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
    # initialize pointer to m and l
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over k, v and update accumulator
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
                        ^
```
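The CUDA FlashAttention and PyTorch rows in the log above ran to completion with fp16 on the RTX 2080 Ti; only the Triton benchmark failed. A small worked sketch of the speedups implied by those reported timings (the numbers are copied from the log, nothing is re-measured):

```python
# Timings in ms copied from the benchmark output above (RTX 2080 Ti, fp16,
# 30 runs each, as printed by torch.utils.benchmark).
flash_ms = {"Forward": 1.29, "Backward": 1.36, "Forward + Backward": 1.84}
pytorch_ms = {"Forward": 7.04, "Backward": 8.31, "Forward + Backward": 15.33}

# Speedup = PyTorch time / FlashAttention (CUDA) time, per pass.
speedups = {name: pytorch_ms[name] / flash_ms[name] for name in flash_ms}

for name, s in speedups.items():
    print(f"{name:>20}: {s:.1f}x")
```

That works out to roughly 5.5x for the forward pass, 6.1x for the backward pass, and 8.3x combined, for this GPU and this benchmark configuration only.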
tridao commented 1 year ago

Yeah, I don't think Triton supports the RTX 2080 very well. Even then, you'd need the Triton version mentioned in the file.
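The `TypeError: dot() got an unexpected keyword argument 'trans_b'` in the traceback is how the version mismatch shows up: the `tl.dot` in the installed Triton does not take `trans_b`. A small sketch for checking this before launching the benchmark; `accepts_keyword` is a hypothetical helper, and applying it to `triton.language.dot` assumes Triton's builtin decorator preserves the wrapped signature (for example via `functools.wraps`), which may not hold in every version:

```python
import inspect
from typing import Callable


def accepts_keyword(fn: Callable, name: str) -> bool:
    """Return True if `fn` would accept a keyword argument called `name`."""
    # inspect.signature follows __wrapped__ by default, so decorated functions
    # report the inner signature when the decorator used functools.wraps.
    params = inspect.signature(fn).parameters
    param = params.get(name)
    if param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    # A **kwargs parameter accepts any keyword name.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


if __name__ == "__main__":
    # Hypothetical stand-in for a dot() that has no trans_b parameter.
    def dot_without_trans(input, other, allow_tf32=True):
        return None

    print(accepts_keyword(dot_without_trans, "trans_b"))     # False
    print(accepts_keyword(dot_without_trans, "allow_tf32"))  # True
```

With Triton installed, `import triton.language as tl` followed by `accepts_keyword(tl.dot, "trans_b")` returning `False` would indicate that the kernel in the traceback cannot compile against that Triton build.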

mit10000 commented 1 year ago

Could you please let me know what I should do? Is buying a new GPU the only way? Thanks!

mit10000 commented 1 year ago

> Yeah, I don't think Triton supports the RTX 2080 very well. Even then, you'd need the Triton version mentioned in the file.

Thanks, I will install the Triton version mentioned in the file.

tridao commented 1 year ago

FlashAttention works on the RTX 2080 (except the backward pass for headdim > 64). This file just benchmarks different implementations (FlashAttention in CUDA, PyTorch, and FlashAttention in Triton).
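The limits stated in this thread for the RTX 2080 (fp16 only, and no backward pass for head dimensions above 64) can be captured in a small pre-flight check. This is a sketch that encodes only what is said here; `check_rtx2080_config` is a hypothetical helper, not a flash-attn API, and it does not model any of the library's other limits:

```python
from typing import Tuple


def check_rtx2080_config(dtype_name: str, headdim: int, needs_backward: bool) -> Tuple[bool, str]:
    """Pre-flight check for the CUDA FlashAttention kernels on an RTX 2080
    (compute capability 7.5), using only the limits stated in this thread."""
    if dtype_name != "float16":
        # bf16 is rejected on GPUs that are neither sm8x nor sm90.
        return False, f"{dtype_name} is not supported here; use float16"
    if needs_backward and headdim > 64:
        return False, f"backward pass not supported for headdim={headdim} > 64"
    return True, "ok"


if __name__ == "__main__":
    print(check_rtx2080_config("float16", 64, needs_backward=True))    # (True, 'ok')
    print(check_rtx2080_config("float16", 128, needs_backward=False))  # (True, 'ok')
    print(check_rtx2080_config("float16", 128, needs_backward=True))
    # (False, 'backward pass not supported for headdim=128 > 64')
```

Returning a reason string alongside the flag keeps the check usable both as an assertion and as a message for the user before any kernel is launched.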