mit10000 opened this issue 1 year ago (status: Open)
RTX 2080 hardware doesn't support bfloat16. You can try changing the dtype to torch.float16 (i.e. fp16).
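Because bf16 support depends on the GPU generation, a script can pick the dtype from the compute capability instead of hard-coding it. This is a minimal sketch. It assumes the rule spelled out in the flash-attn 1.x error message further down this thread: fp16 is always accepted, and bf16 only on sm8x or sm90. The helper pick_flash_attn_dtype is hypothetical and not part of flash-attn. It prefers bf16 where that check allows it and otherwise falls back to fp16.

```python
def pick_flash_attn_dtype(capability):
    """Return "bfloat16" or "float16" for a (major, minor) compute capability.

    Hypothetical helper mirroring the flash-attn 1.x check:
    q_dtype == kFloat16 || ((is_sm8x || is_sm90) && q_dtype == kBFloat16)
    """
    major, minor = capability
    is_sm8x = major == 8
    is_sm90 = (major, minor) == (9, 0)
    # bf16 only where the check allows it; fp16 is accepted everywhere.
    return "bfloat16" if (is_sm8x or is_sm90) else "float16"


# RTX 2080 / 2080 Ti are Turing GPUs with compute capability 7.5 (sm75).
print(pick_flash_attn_dtype((7, 5)))  # float16
print(pick_flash_attn_dtype((8, 0)))  # bfloat16 (e.g. A100)
```

With PyTorch available, getattr(torch, pick_flash_attn_dtype(torch.cuda.get_device_capability())) turns the returned name into torch.float16 or torch.bfloat16.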
Thank you so much @tridao for your help. I changed it to dtype = torch.float16, and now I get the following output:
FlashAttention - Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5d2ee00280>
fn_amp(*inputs, **kwinputs)
1.29 ms
1 measurement, 30 runs , 8 threads
FlashAttention - Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5ca34443a0>
y.backward(grad, retain_graph=True)
1.36 ms
1 measurement, 30 runs , 8 threads
FlashAttention - Forward + Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5ca4948790>
f(grad, *inputs, **kwinputs)
1.84 ms
1 measurement, 30 runs , 8 threads
PyTorch Attention - Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5ca345dfa0>
fn_amp(*inputs, **kwinputs)
7.04 ms
1 measurement, 30 runs , 8 threads
PyTorch Attention - Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5ca3438a60>
y.backward(grad, retain_graph=True)
8.31 ms
1 measurement, 30 runs , 8 threads
PyTorch Attention - Forward + Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5ca34443a0>
f(grad, *inputs, **kwinputs)
15.33 ms
1 measurement, 30 runs , 8 threads
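The timings above are easier to compare as ratios. This sketch only does arithmetic on the printed numbers. It assumes they are comparable single-number summaries from the same run, and says nothing about other sequence lengths, head dimensions or GPUs.

```python
# Times in ms copied from the benchmark output above: (FlashAttention, PyTorch).
timings_ms = {
    "Forward": (1.29, 7.04),
    "Backward": (1.36, 8.31),
    "Forward + Backward": (1.84, 15.33),
}


def speedups(timings):
    """Return {pass name: PyTorch time / FlashAttention time}, rounded to 2 decimals."""
    return {
        name: round(torch_ms / flash_ms, 2)
        for name, (flash_ms, torch_ms) in timings.items()
    }


for name, ratio in speedups(timings_ms).items():
    print(f"{name}: {ratio}x")
```

On this run, that works out to roughly 5.5x for the forward pass, 6.1x for the backward pass and 8.3x combined.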
FlashAttention Triton - Forward pass
Traceback (most recent call last):
File "
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 937, in build_triton_ir
    generator.visit(fn.parse())
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 183, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/opt/conda/lib/python3.8/ast.py", line 379, in generic_visit
    self.visit(item)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 252, in visit_FunctionDef
    has_ret = self.visit_compound_statement(node.body)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 177, in visit_compound_statement
    self.last_ret_type = self.visit(stmt)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 678, in visit_For
    self.visit_compound_statement(node.body)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 177, in visit_compound_statement
    self.last_ret_type = self.visit(stmt)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 319, in visit_AugAssign
    self.visit(assign)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 301, in visit_Assign
    values = self.visit(node.value)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 339, in visit_BinOp
    rhs = self.visit(node.right)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 797, in visit_Call
    return fn(*args, _builder=self.builder, **kws)
  File "/opt/conda/lib/python3.8/site-packages/triton/impl/base.py", line 22, in wrapper
    return fn(*args, **kwargs)
TypeError: dot() got an unexpected keyword argument 'trans_b'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "benchmarks/benchmark_causal.py", line 96, in <module>
    # off_h = tl.program_id(2)
    # off_hb = off_b * nheads + off_h
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    # Adding parenthesis around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == 'matrix':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
    # initialize pointer to m and l
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over k, v and update accumulator
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
              ^
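The failing line is qk += tl.dot(q, k, trans_b=True): the installed Triton rejects the trans_b keyword, which the Triton version pinned in the FlashAttention file accepts. In the kernel, the k tile is loaded as (BLOCK_N, headdim), since k_ptrs uses offs_n[:, None] * stride_kn + offs_d[None, :]. The call therefore computes the score tile Q·Kᵀ of shape (BLOCK_M, BLOCK_N). The pure-Python reference below only pins down that meaning so a port can be checked against it. In Triton releases without trans_b, the usual rewrite is tl.dot(q, tl.trans(k)), but that is an assumption to verify against the installed Triton, not a fix confirmed in this thread. The helper dot_trans_b is hypothetical.

```python
def dot_trans_b(q, k):
    """Reference for tl.dot(q, k, trans_b=True), i.e. q @ k^T.

    q: BLOCK_M rows of length headdim; k: BLOCK_N rows of length headdim.
    Returns a BLOCK_M x BLOCK_N list of float scores, like the qk tile.
    """
    return [
        [float(sum(qi * kj for qi, kj in zip(q_row, k_row))) for k_row in k]
        for q_row in q
    ]


q = [[1.0, 2.0], [3.0, 4.0]]              # BLOCK_M=2, headdim=2
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # BLOCK_N=3, headdim=2
print(dot_trans_b(q, k))  # [[1.0, 2.0, 3.0], [3.0, 4.0, 7.0]]
```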
Yeah, I don't think Triton supports the RTX 2080 very well. Even then, you'd need the Triton version mentioned in the file.
Could you please let me know what I should do? Is buying a new GPU the only way? Thanks!
Thanks, I will install the Triton version specified in the file.
FlashAttention works on the RTX 2080 (except for the backward pass with headdim > 64). This file just benchmarks different implementations (FlashAttention in CUDA, PyTorch, and FlashAttention in Triton).
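That support statement can be turned into a guard for a benchmark script, so that only the passes expected to work run on an RTX 2080. This is a hypothetical helper (passes_to_benchmark is not a flash-attn function). It encodes only what is stated here: on sm75, the compute capability of the RTX 2080 family, the backward pass needs headdim <= 64. Limits on other GPUs and the dtype rule are deliberately not modeled.

```python
ALL_PASSES = ["Forward", "Backward", "Forward + Backward"]


def passes_to_benchmark(headdim, capability):
    """Return the FlashAttention (CUDA) benchmark passes to run.

    Hypothetical guard: only encodes "no backward pass for headdim > 64"
    on sm75 (RTX 2080 / 2080 Ti); other architectures get every pass.
    """
    is_sm75 = tuple(capability) == (7, 5)
    if is_sm75 and headdim > 64:
        # Per this thread, the backward pass is unsupported here.
        return ["Forward"]
    return list(ALL_PASSES)


for hd in (32, 64, 128):
    print(hd, passes_to_benchmark(hd, (7, 5)))
```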
Environment:
- Docker image: NVIDIA PyTorch Release 22.04
- Python 3.8.13
- torch 1.12.0a0+bd13bc6
- flash-attn 1.0.7
- triton 2.0.0
- GPU: NVIDIA GeForce RTX 2080 Ti
- CUDA 11.4
Command: python benchmarks/benchmark_causal.py
FlashAttention - Forward pass
Traceback (most recent call last):
  File "benchmarks/benchmark_causal.py", line 91, in <module>
    benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b s) ...'),
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/utils/benchmark.py", line 89, in benchmark_all
    benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose,
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/utils/benchmark.py", line 17, in benchmark_forward
    fn_amp(*inputs, **kwinputs)
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/utils/benchmark.py", line 15, in fn_amp
    fn(*inputs, **kwinputs)
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 256, in flash_attn_unpadded_qkvpacked_func
    return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 58, in forward
    out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
  File "/opt/conda/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 21, in _flash_attn_forward
    softmax_lse, rng_state, *rest = flash_attn_cuda.fwd(
RuntimeError: Expected q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
Could you please help?