ROCm / triton

Development repository for the Triton language and compiler
MIT License

Compatible with PyTorch's triton #507

Closed: netw0rkf10w closed this issue 3 months ago

netw0rkf10w commented 5 months ago

Hello,

I have installed PyTorch 2.3.0.dev20240211+rocm5.7, which includes Triton 3.0.0+dafe145982. Could you please tell me whether the flash attention implementation in this repo can be used with that version of Triton? If I have to install the version from this repo instead, will that break my PyTorch installation?

Thank you very much in advance for your answer!

zhanglx13 commented 5 months ago

I'm not aware that Triton 3.0 has been released.
@jataylo Could you chime in about this combination?

netw0rkf10w commented 5 months ago

@zhanglx13 If I install a previous version of PyTorch (e.g., 2.1.2) that comes with Triton 2, do I have to install this repo to get flash attention working? Thank you in advance for your answer.

jataylo commented 5 months ago

Hey @netw0rkf10w @zhanglx13

The 3.0 version number is there so that we keep the same Triton version as the regular openai/nvidia triton wheel produced for PyTorch, but it is actually built from an earlier commit of the triton-mlir branch. We may want to rethink this versioning, or bump the Triton version number in our branch, so that this is less confusing.

As for the flash attention kernel, the question is whether this branch, https://github.com/ROCm/triton/tree/pytorch_nightly/23_11_2023, which was cut off from the triton-mlir branch on Nov 3rd, supports it, @zhanglx13.
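
As a hedged aside on the versioning point: the traceback below shows that the pytorch-triton-rocm wheel installs its code under site-packages/triton, the same import name the upstream triton wheel uses, so more than one distribution can end up providing the same module. A minimal stdlib-only sketch (the helper name installed_triton_distributions is hypothetical) that lists every installed distribution whose name contains "triton", together with its version:

```python
from importlib import metadata


def installed_triton_distributions():
    """Return {distribution name: version} for every installed distribution
    whose name contains "triton" (e.g. triton, pytorch-triton-rocm)."""
    found = {}
    for dist in metadata.distributions():
        name = dist.metadata.get("Name") or ""
        if "triton" in name.lower():
            found[name] = dist.version
    return found


if __name__ == "__main__":
    # Print in pip's "name==version" style so the output can be compared with pip list.
    for name, version in sorted(installed_triton_distributions().items()):
        print(f"{name}=={version}")
```

Seeing two entries here (for example both triton and pytorch-triton-rocm) would suggest that one installation may be shadowing the other.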

netw0rkf10w commented 5 months ago

@jataylo Thanks for the information.

FYI, I've tested with PyTorch nightly (pytorch-triton-rocm==3.0.0+dafe145982) and it didn't work:

Traceback (most recent call last):
  File "/home/user/code/benchmark_flash_attention.py", line 174, in <module>
    ex()
  File "/home/user/code/benchmark_flash_attention.py", line 106, in ex
    r4_fp16 = flash_attention_triton(q.half(), k.half(), v.half())
  File "/home/user/code/benchmark_flash_attention.py", line 70, in flash_attention_triton
    tri_out, _ = attention(q, k, v, None, metadata)
  File "/home/user/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 572, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/user/code/triton/python/perf-kernels/flash-attention.py", line 801, in forward
    attn_fwd[grid](
  File "<string>", line 74, in attn_fwd
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 552, in compile
    next_module = compile_kernel(module)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 427, in <lambda>
    lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1142, in ast_to_ttir
    generator.visit(fn.parse())
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1025, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 292, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1025, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 362, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 287, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1025, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 414, in visit_Assign
    values = self.visit(node.value)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1025, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 621, in visit_IfExp
    raise UnsupportedLanguageConstruct(
triton.compiler.errors.UnsupportedLanguageConstruct: at 45:14:        # small for all start_m so for those we return early.
        if start_m * BLOCK_M > seqlen_q:
            return
        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
    else:
        cu_seqlens_q_start = 0
        cu_seqlens_k_start = 0
        seqlen_q = max_seqlens_q
        seqlen_k = max_seqlens_k
    off_h_k = off_h_q % hk if is_mqa else off_h_q

zhanglx13 commented 5 months ago

@netw0rkf10w Can you try installing triton from this repo with the following steps:

git clone https://github.com/ROCm/triton.git
cd triton/python
pip install -e .

It should update Triton to the top of tree (ToT) and work with PyTorch 2.1, 2.2, and 2.3. Note that if you want AMD f8 data type support, you need PyTorch 2.3.
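
The PyTorch 2.3 requirement for AMD f8 can be checked up front. A minimal sketch, assuming that any version string whose major.minor is at least 2.3 qualifies, including nightlies such as 2.3.0.dev20240211+rocm5.7; supports_amd_f8 is a hypothetical helper, not part of PyTorch or Triton:

```python
import re


def supports_amd_f8(torch_version: str) -> bool:
    """Return True if a PyTorch version string is 2.3 or newer.

    Assumption: per the comment above, AMD f8 support needs PyTorch 2.3, so any
    (major, minor) >= (2, 3) is accepted. Suffixes such as ".dev20240211+rocm5.7"
    are ignored.
    """
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        raise ValueError(f"unrecognized version string: {torch_version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (2, 3)
```

In practice it would be called as supports_amd_f8(torch.__version__).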

netw0rkf10w commented 5 months ago

@zhanglx13 I installed from this repo and now I get the following errors when running the code:

Traceback (most recent call last):
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1222, in ast_to_ttir
    generator.visit(fn.parse())
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 299, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 372, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 294, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 641, in visit_If
    self.visit_compound_statement(node.body)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 294, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 628, in visit_If
    self.visit_if_scf(cond, node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 597, in visit_if_scf
    self.visit_then_else_blocks(node, liveins, then_block, else_block)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 501, in visit_then_else_blocks
    self.visit_compound_statement(node.body)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 294, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 424, in visit_Assign
    values = self.visit(node.value)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1020, in visit_Call
    return self.call_JitFunction(fn, args, kws)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 988, in call_JitFunction
    generator.visit(fn.parse())
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 299, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 372, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 294, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 884, in visit_For
    self.visit_compound_statement(node.body)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 294, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 447, in visit_AugAssign
    self.visit(assign)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 424, in visit_Assign
    values = self.visit(node.value)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 476, in visit_BinOp
    rhs = self.visit(node.right)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1104, in visit
    ret = super().visit(node)
  File "/opt/cray/pe/python/3.10.10/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1026, in visit_Call
    return fn(*args, **extra_kwargs, **kws)
  File "/home/user/.local/lib/python3.10/site-packages/triton/language/core.py", line 27, in wrapper
    return fn(*args, **kwargs)
  File "/home/user/.local/lib/python3.10/site-packages/triton/language/core.py", line 1059, in dot
    return semantic.dot(input, other, acc, allow_tf32, max_num_imprecise_acc, out_dtype, _builder)
  File "/home/user/.local/lib/python3.10/site-packages/triton/language/semantic.py", line 1271, in dot
    and rhs.shape[1].value >= 4, \
AssertionError: All values in both first input shape ([constexpr[256], constexpr[8]]) and second input shape ([constexpr[8], constexpr[64]]) must be >= 4!

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/user/code/benchmark_flash_attention.py", line 173, in <module>
    ex()
  File "/home/user/code/benchmark_flash_attention.py", line 106, in ex
    r4_fp16 = flash_attention_triton(q.half(), k.half(), v.half())
  File "/home/user/code/benchmark_flash_attention.py", line 70, in flash_attention_triton
    tri_out, _ = attention(q, k, v, None, metadata)
  File "/home/user/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 572, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/user/code/torching/src/torching/flash_attention/flash_attention_src.py", line 801, in forward
    attn_fwd[grid](
  File "/home/user/.local/lib/python3.10/site-packages/triton/runtime/jit.py", line 555, in run
    self.cache[device][key] = compile(
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 591, in compile
    next_module = compile_kernel(module)
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 455, in <lambda>
    ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1231, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 145:16:        if seqlen_k >= BLOCK_N:
            acc, l_i, m_i = _attn_fwd_inner(
                acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                start_m, seqlen_aligned,
                dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr,
                BLOCK_M, BLOCK_DMODEL, BLOCK_N,
                4 - STAGE, offs_m, offs_n,
                PRE_LOAD_V,
                False, seqlen_aligned,
                bias_ptr,
                ENABLE_DROPOUT,
                RETURN_ENCODED_SOFTMAX
                ^
AssertionError('All values in both first input shape ([constexpr[256], constexpr[8]]) and second input shape ([constexpr[8], constexpr[64]]) must be >= 4!')

zhanglx13 commented 5 months ago

It seems the tile size is too small: the kernel is doing (256x8) x (8x64), but lhs.shape[1] must be >= 16. It's hard to say whether this is a bug. Can you post the script you are running?
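
The constraint can be checked on the host before launching the kernel. A hypothetical sketch: the thresholds (every dimension >= 4, reduction dimension K >= 16) are taken from the assertion message and the comment above, not from Triton's source, and they may differ across Triton versions and GPUs. In the failing dot the operands are (BLOCK_M x D_HEAD) and (D_HEAD x BLOCK_N), so K is the head dimension:

```python
def check_dot_shapes(lhs_shape, rhs_shape, min_dim=4, min_k=16):
    """Hypothetical pre-flight check for a tl.dot of (M x K) @ (K x N).

    Mirrors the constraint discussed in this thread: every dimension must be
    >= min_dim and the reduction dimension K must be >= min_k.
    Returns a list of human-readable problems (empty if the shapes pass).
    """
    (m, k), (k2, n) = lhs_shape, rhs_shape
    problems = []
    if k != k2:
        problems.append(f"inner dimensions differ: {k} vs {k2}")
    for name, value in (("M", m), ("K", k), ("N", n)):
        if value < min_dim:
            problems.append(f"{name}={value} < {min_dim}")
    if k < min_k:
        problems.append(f"K={k} < {min_k} (head dimension too small)")
    return problems
```

For the shapes in the traceback, check_dot_shapes((256, 8), (8, 64)) reports only the K problem, while (256, 16) x (16, 64) passes.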

netw0rkf10w commented 5 months ago

Oh, increasing D_HEAD from 8 to 16 works! I wonder why the head dimension has to be >= 16, since there's no such constraint in the original CUDA implementation...

netw0rkf10w commented 5 months ago

So apparently D_HEAD should be large, because otherwise it's even slower than the naive implementation.

Below are some benchmark results for batch_size = 16, sequence_length = 256, num_heads in [8, 16, 32], and dims = [64, 128]:

[----------------------------- FlashAttention -----------------------------]
                                 |  sdp_attention  |  flash  |  flash-triton
1 threads: -----------------------------------------------------------------
      [torch.float16, 8, 64]     |       274.9     |  142.5  |     317.6    
      [torch.float16, 8, 128]    |       354.2     |  245.8  |     365.2    
      [torch.float16, 16, 64]    |       412.4     |  200.2  |     355.3    
      [torch.float16, 16, 128]   |       688.1     |  380.9  |     434.0    
      [torch.float16, 32, 64]    |       790.3     |  375.4  |     243.9    
      [torch.float16, 32, 128]   |      1356.3     |  743.4  |     325.8    
      [torch.bfloat16, 8, 64]    |       350.7     |  117.2  |     388.1    
      [torch.bfloat16, 8, 128]   |       396.5     |  255.8  |     431.1    
      [torch.bfloat16, 16, 64]   |       548.1     |  196.6  |     453.4    
      [torch.bfloat16, 16, 128]  |       771.7     |  384.4  |     540.1    
      [torch.bfloat16, 32, 64]   |      1071.6     |  368.4  |     355.8    
      [torch.bfloat16, 32, 128]  |      1519.3     |  744.2  |     497.0    

Times are in microseconds (us).

The Triton version is faster than the naive implementation only from 16 heads upward (with dimension >= 64), and it's only faster than the other implementation (https://github.com/ROCm/flash-attention) from 32 heads upward (with dimension >= 64). The speedup in the (32, 128) case is very impressive, though.
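
These observations can be read off the table mechanically. A small sketch that hard-codes the numbers from the table above (times in microseconds) and lists the configurations where flash-triton is faster than a chosen baseline column; RESULTS and faster_configs are illustrative names:

```python
# (dtype, num_heads, head_dim) -> (sdp_attention, flash, flash-triton), in microseconds,
# copied from the benchmark table above (batch_size = 16, sequence_length = 256).
RESULTS = {
    ("float16", 8, 64): (274.9, 142.5, 317.6),
    ("float16", 8, 128): (354.2, 245.8, 365.2),
    ("float16", 16, 64): (412.4, 200.2, 355.3),
    ("float16", 16, 128): (688.1, 380.9, 434.0),
    ("float16", 32, 64): (790.3, 375.4, 243.9),
    ("float16", 32, 128): (1356.3, 743.4, 325.8),
    ("bfloat16", 8, 64): (350.7, 117.2, 388.1),
    ("bfloat16", 8, 128): (396.5, 255.8, 431.1),
    ("bfloat16", 16, 64): (548.1, 196.6, 453.4),
    ("bfloat16", 16, 128): (771.7, 384.4, 540.1),
    ("bfloat16", 32, 64): (1071.6, 368.4, 355.8),
    ("bfloat16", 32, 128): (1519.3, 744.2, 497.0),
}


def faster_configs(results, baseline_index, candidate_index=2):
    """Sorted configs where the candidate column has a lower time than the baseline."""
    return sorted(
        cfg for cfg, times in results.items()
        if times[candidate_index] < times[baseline_index]
    )
```

With the flash column as baseline (index 1) it returns only the 32-head rows; with sdp_attention (index 0) it returns every row with 16 or 32 heads.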

Could you tell me if these results correspond to what you have in mind regarding these implementations?

zhanglx13 commented 5 months ago

@netw0rkf10w I'll assume [torch.float16, 8, 64] means 8 heads with head_dim = 64. Here are my observations from your benchmark, made without knowing which GPU you are using (it seems to be an MI200-series GPU) or the block size and other tuning parameters.

netw0rkf10w commented 5 months ago

Hi @zhanglx13. Sorry for the late reply, my server has been down for the last few days, so I've been waiting for it to be fixed before giving more feedback.

First of all, your last commit (https://github.com/ROCm/triton/commit/35edd6a650e3f6a56e3c2db9a54fdc2f8a6505e1) leads to numerically incorrect results. I had to revert to the commit before it.

Second, below are some benchmarking results on an MI250x:

[------------------------------------------ FlashAttention ------------------------------------------]
                                           |  sdp_attention  |  torch sdpa  |   flash   |  triton-rocm
1 threads: -------------------------------------------------------------------------------------------
      [torch.float16, 12, 64, 256, 64]     |       872.7     |     741.4    |    234.8  |      510.6  
      [torch.float16, 12, 64, 256, 128]    |      1645.0     |    1463.1    |    466.6  |      509.3  
      [torch.float16, 16, 64, 256, 64]     |      1105.7     |     985.4    |    314.6  |      605.1  
      [torch.float16, 16, 64, 256, 128]    |      2192.0     |    1948.7    |    628.1  |      683.3  
      [torch.float16, 16, 64, 576, 64]     |      6016.5     |    5045.2    |   1864.9  |     1771.6  
      [torch.float16, 16, 64, 576, 128]    |     12003.2     |   10068.3    |   3715.1  |     3533.0  
      [torch.float16, 16, 128, 256, 64]    |      2246.9     |    2251.9    |    708.6  |      876.5  
      [torch.float16, 16, 128, 256, 128]   |      4476.0     |    4481.0    |   1390.7  |     1173.0  
      [torch.float16, 16, 128, 784, 64]    |     18295.2     |   16747.2    |   6723.6  |     5801.8  
      [torch.float16, 16, 128, 784, 128]   |     36606.3     |   33475.4    |  13345.4  |    11521.5  
      [torch.bfloat16, 12, 64, 256, 64]    |      1022.0     |     883.2    |    224.9  |      754.5  
      [torch.bfloat16, 12, 64, 256, 128]   |      1934.5     |    1743.9    |    438.5  |     1006.2  
      [torch.bfloat16, 16, 64, 256, 64]    |      1296.5     |    1170.6    |    299.1  |      956.1  
      [torch.bfloat16, 16, 64, 256, 128]   |      2574.3     |    2320.5    |    580.9  |     1357.3  
      [torch.bfloat16, 16, 64, 576, 64]    |      6028.7     |    4973.1    |   1771.1  |     2963.7  
      [torch.bfloat16, 16, 64, 576, 128]   |     12071.5     |   10050.5    |   3492.6  |     5946.5  
      [torch.bfloat16, 16, 128, 256, 64]   |      2197.3     |    2200.5    |    690.4  |     1244.4  
      [torch.bfloat16, 16, 128, 256, 128]  |      4382.2     |    4385.6    |   1345.2  |     1910.6  
      [torch.bfloat16, 16, 128, 784, 64]   |     16343.5     |   14736.8    |   6560.7  |    10183.0  
      [torch.bfloat16, 16, 128, 784, 128]  |     32673.9     |   29445.0    |  13011.1  |    20128.2  

The first column gives [dtype, num_heads, head_dim, seq_len, batch_size]. The num_heads, head_dim, and seq_len values above correspond roughly to the configurations of ViT-Base, ViT-Large, and ViT-Huge (the head dimension of ViT-Huge is 80, which is not supported by flash-attn, so I used 128 instead).

PyTorch was not compiled with flash attention or memory-efficient attention (UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:264.)), so the torch sdpa column is included only for reference and does not reflect the true performance of torch.nn.functional.scaled_dot_product_attention once the compilation issue has been fixed (by your colleague at AMD).

From the results, the Triton version slightly outperforms the ROCm version (flash) on long sequence lengths in FP16, while the ROCm version clearly outperforms the Triton version in BFloat16!

netw0rkf10w commented 5 months ago

Moreover, for your information, the above results are far behind those on an NVIDIA A100:

[--------------------------------------- FlashAttention ---------------------------------------]
                                           |  sdp_attention  |  torch sdpa  |  flash   |  triton
1 threads: -------------------------------------------------------------------------------------
      [torch.float16, 12, 64, 256, 64]     |       518.2     |      98.1    |   119.9  |   264.7
      [torch.float16, 12, 64, 256, 128]    |       940.1     |     185.7    |   178.1  |   202.4
      [torch.float16, 16, 64, 256, 64]     |       639.5     |     127.9    |   121.5  |   289.8
      [torch.float16, 16, 64, 256, 128]    |      1240.0     |     244.5    |   233.3  |   266.9
      [torch.float16, 16, 64, 576, 64]     |      3374.3     |     695.5    |   641.5  |   744.5
      [torch.float16, 16, 64, 576, 128]    |      6707.5     |    1380.6    |  1272.6  |  1229.3
      [torch.float16, 16, 128, 256, 64]    |       735.9     |     242.1    |   230.7  |   416.9
      [torch.float16, 16, 128, 256, 128]   |      1431.7     |     453.4    |   412.4  |   526.6
      [torch.float16, 16, 128, 784, 64]    |      6392.1     |    2041.1    |  1915.0  |  2397.9
      [torch.float16, 16, 128, 784, 128]   |     12817.3     |    4054.9    |  3803.6  |  4520.8
      [torch.bfloat16, 12, 64, 256, 64]    |       493.7     |      97.7    |    92.7  |        
      [torch.bfloat16, 12, 64, 256, 128]   |       950.7     |     183.3    |   175.2  |        
      [torch.bfloat16, 16, 64, 256, 64]    |       645.7     |     125.7    |   119.4  |        
      [torch.bfloat16, 16, 64, 256, 128]   |      1253.9     |     240.0    |   229.1  |        
      [torch.bfloat16, 16, 64, 576, 64]    |      3143.5     |     686.3    |   632.4  |        
      [torch.bfloat16, 16, 64, 576, 128]   |      6245.4     |    1360.7    |  1247.5  |        
      [torch.bfloat16, 16, 128, 256, 64]   |       739.7     |     228.6    |   207.6  |        
      [torch.bfloat16, 16, 128, 256, 128]  |      1435.2     |     442.5    |   402.7  |        
      [torch.bfloat16, 16, 128, 784, 64]   |      6451.8     |    2009.5    |  1875.7  |        
      [torch.bfloat16, 16, 128, 784, 128]  |     12877.8     |    3987.9    |  3732.9  |        

zhanglx13 commented 5 months ago

@netw0rkf10w I tried to reproduce the perf numbers on an MI250X, and here is what I got:

case | config                             | sdp_attention | torch sdpa |   flash | triton-us | triton-us tuned
-----+------------------------------------+---------------+------------+---------+-----------+----------------
   1 | [torch.float16, 12, 64, 256, 64]   |         872.7 |      741.4 |   234.8 |     510.6 |           176.4
   1 | [torch.float16, 12, 64, 256, 128]  |        1645.0 |     1463.1 |   466.6 |     509.3 |           322.4
   1 | [torch.float16, 16, 64, 256, 64]   |        1105.7 |      985.4 |   314.6 |     605.1 |           225.4
   1 | [torch.float16, 16, 64, 256, 128]  |        2192.0 |     1948.7 |   628.1 |     683.3 |           420.0
   2 | [torch.float16, 16, 64, 576, 64]   |        6016.5 |     5045.2 |  1864.9 |    1771.6 |          1504.8
   2 | [torch.float16, 16, 64, 576, 128]  |       12003.2 |    10068.3 |  3715.1 |    3533.0 |          2998.4
   3 | [torch.float16, 16, 128, 256, 64]  |        2246.9 |     2251.9 |   708.6 |     876.5 |           557.2
   3 | [torch.float16, 16, 128, 256, 128] |        4476.0 |     4481.0 |  1390.7 |    1173.0 |          1070.4
   4 | [torch.float16, 16, 128, 784, 64]  |       18295.2 |    16747.2 |  6723.6 |    5801.8 |
   4 | [torch.float16, 16, 128, 784, 128] |       36606.3 |    33475.4 | 13345.4 |   11521.5 |

As you can see, some tuning is needed to bring up the perf.

For bfloat16, there is a known issue with the conversion between f32 and bf16 in the kv-loop, which is why the bf16 perf is worse. We are working on this issue.
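
For intuition about what that conversion involves, here is a scalar sketch of float32 to bfloat16 rounding (round-to-nearest-even on the bit pattern) and back. This is the standard bit-level technique, not the code the Triton AMD backend actually emits; on a target without a native conversion instruction, each element may need an integer sequence like this, which is one plausible reason the conversion is costly inside the kv-loop. It also illustrates why the bf16 errors in this thread (around 1e-2) are larger than fp16's: bf16 keeps only 7 explicit mantissa bits, versus 10 for fp16.

```python
import math
import struct


def f32_to_bf16_bits(x: float) -> int:
    """Round a float32-representable value to bfloat16 and return its 16-bit pattern.

    bfloat16 keeps float32's sign bit and 8 exponent bits but only the top 7
    mantissa bits, so the low 16 bits of the float32 pattern are rounded away.
    """
    if math.isnan(x):
        return 0x7FC0  # canonical quiet NaN
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add 0x7FFF plus the lowest kept bit so that exact ties round to even.
    bits += 0x7FFF + ((bits >> 16) & 1)
    return (bits >> 16) & 0xFFFF


def bf16_bits_to_f32(b: int) -> float:
    """Widen a bfloat16 bit pattern back to float32 by appending 16 zero bits."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]
```

Widening (bf16 to f32) is just a shift, while narrowing needs the rounding step, so the narrowing direction is where the extra work sits in this sketch.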

As for the numerical errors with my last commit (https://github.com/ROCm/triton/commit/35edd6a650e3f6a56e3c2db9a54fdc2f8a6505e1), I tried some configs and they work. However, other configs may fail, since that PR was primarily for GEMM kernels and was not extensively tested with FA. Could you share the failing FA configs so that I can take a look?

netw0rkf10w commented 5 months ago

Hi @zhanglx13!

Thanks a lot for your reply! The results with tuning look very promising!

I'm trying to get some more results before replying, but I encountered an issue that I haven't been able to resolve:

Traceback (most recent call last):
  File "/home/user/code/tests/benchmark_flash_attention.py", line 271, in <module>
    ex()
  File "/home/user/code/tests/benchmark_flash_attention.py", line 163, in ex
    triton_rocm_fp16 = flash_attention_triton_rocm(q.half(), k.half(), v.half())
  File "/home/user/code/tests/benchmark_flash_attention.py", line 113, in flash_attention_triton_rocm
    output = flash_attn_triton_rocm(q, k, v)
  File "/home/user/code/torching/src/torching/flash_attention/triton_rocm/flash_attention.py", line 916, in attention
    output, _ = _attention.apply(q, k, v, None, metadata)
  File "/home/user/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/user/code/torching/src/torching/flash_attention/triton_rocm/flash_attention.py", line 805, in forward
    attn_fwd[grid](
  File "/home/user/.local/lib/python3.10/site-packages/triton/runtime/jit.py", line 581, in run
    bin.c_wrapper(
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 745, in __getattribute__
    self._init_handles()
  File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 736, in _init_handles
    mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
SystemError: <built-in function load_binary> returned NULL without setting an exception

It has happened to me quite often, and each time I had to remove and re-install everything for it to work again. This has become increasingly frustrating, to the point that I wanted to give up on Triton altogether. Do you happen to know how to fix this issue?

zhanglx13 commented 5 months ago

I understand it's very frustrating and I think it happened to me once as well.

File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/compiler.py"

It looks like you are not using the triton installed from this repo. Here is what you can try

  1. Remove the pre-installed triton: pip uninstall triton
  2. Go to the python dir of this repo: cd python. This could make a difference since this is where the setup.py lives.
  3. Install this triton: pip install -e .

If you still see this error and it can only be fixed by re-installing everything, let us know, and @jataylo from PyTorch may have more insight into the issue.
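
To confirm which copy of triton Python will actually import (the site-packages path in the traceback is what gave this away), a minimal stdlib sketch; the helper names and the example checkout path are assumptions. importlib.util.find_spec locates the module without importing it, and for an editable install made from the repo's python directory the resolved path is expected to lie inside the checkout rather than in site-packages:

```python
import importlib.util
import os


def triton_import_location(module_name="triton"):
    """Return the directory Python would import `module_name` from, or None if
    it is not importable. The module itself is not imported."""
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        return None
    return os.path.dirname(os.path.abspath(spec.origin))


def is_from_checkout(location, checkout_dir):
    """True if `location` lies inside `checkout_dir` (e.g. ~/code/triton/python)."""
    if location is None:
        return False
    location = os.path.realpath(location)
    checkout_dir = os.path.realpath(os.path.expanduser(checkout_dir))
    return os.path.commonpath([location, checkout_dir]) == checkout_dir


if __name__ == "__main__":
    where = triton_import_location()
    print("triton resolves to:", where)
    # Hypothetical checkout path; replace with wherever the repo was cloned.
    print("from checkout:", is_from_checkout(where, "~/code/triton/python"))
```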

netw0rkf10w commented 5 months ago

File "/home/user/.local/lib/python3.10/site-packages/triton/compiler/compiler.py"

It looks like you are not using the triton installed from this repo. Here is what you can try

Sorry, that was because I tried both pip install -e . --user and pip install . --user, trying the latter after the former had failed to fix the issue. I only used this repo for the installation. By the way, pip uninstall triton never worked for me:

$ pip uninstall triton
Found existing installation: triton 2.1.0
Can't uninstall 'triton'. No files were found to uninstall.

For uninstallation, I had to do:

rm -rf ~/.local/lib/python3.10/site-packages/triton* 
rm ~/.local/lib/python3.10/site-packages/__editable__*triton* 
rm -rf ~/code/triton/build ~/code/triton/python/triton.egg-info
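
The manual cleanup can be previewed before deleting anything. A sketch that only lists what the first two rm patterns above would match in the user site-packages (site.getusersitepackages()); it deletes nothing, and the pattern list is copied from the commands above, so it may need adjusting for other setups:

```python
import glob
import os
import site

# Patterns copied from the manual rm commands above; adjust for your environment.
PATTERNS = ("triton*", "__editable__*triton*")


def leftover_triton_files(site_dir=None):
    """List paths in the (user) site-packages that match the cleanup patterns.

    Only lists candidates for removal; nothing is deleted.
    """
    site_dir = site_dir or site.getusersitepackages()
    matches = []
    for pattern in PATTERNS:
        matches.extend(glob.glob(os.path.join(site_dir, pattern)))
    return sorted(set(matches))


if __name__ == "__main__":
    for path in leftover_triton_files():
        print(path)
```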

But now uninstalling and re-installing Triton doesn't help anymore; the error remains. I don't know why :(

netw0rkf10w commented 5 months ago

Hmm OK, so the last re-installation didn't work because I forgot to run git checkout d6f14d374e0a02ce113687954accbb3526b3324f, i.e., to remove your last commit. Now it's working again!

What annoys me is that now it has numerical issues:

B = 5; H = 4; L = 10; D = 16

max diff r1_fp32 vs r1_fp16 = 0.0014356374740600586
max diff r1_fp32 vs r1_bf16 = 0.015220880508422852
max diff r1_fp32 vs r2_fp32 = 0.0
max diff r1_fp32 vs r2_fp16 = 0.0014356374740600586
max diff r1_fp32 vs r2_bf16 = 0.015220880508422852
max diff r1_fp32 vs r3_fp16 = 0.001074075698852539
max diff r1_fp32 vs r3_bf16 = 0.010561943054199219
max diff r1_fp32 vs triton_rocm_fp16 = 4.576810836791992
max diff r1_fp32 vs triton_rocm_bf16 = 4.579740524291992

This didn't happen before (it only happened with your most recent commit, as I mentioned earlier).

For larger sizes the results seem good:

For B = 64; H = 12; L = 256; D = 64:

max diff r1_fp32 vs r1_fp16 = 0.0020145773887634277
max diff r1_fp32 vs r1_bf16 = 0.01580953598022461
max diff r1_fp32 vs r2_fp32 = 8.344650268554688e-07
max diff r1_fp32 vs r2_fp16 = 0.0019515752792358398
max diff r1_fp32 vs r2_bf16 = 0.013421595096588135
max diff r1_fp32 vs r3_fp16 = 0.0009750127792358398
max diff r1_fp32 vs r3_bf16 = 0.00976407527923584
max diff r1_fp32 vs triton_rocm_fp16 = 0.0013276338577270508
max diff r1_fp32 vs triton_rocm_bf16 = 0.009515345096588135
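
For tiny shapes such as B = 5, H = 4, L = 10, D = 16, an independent CPU reference can help tell a kernel bug apart from ordinary fp16/bf16 rounding. A pure-Python sketch of what sdp_attention in the script below computes for a single (batch, head) slice, plus a max-abs-diff helper; the names are hypothetical, and converting tensors with .tolist() and looping over B and H is left to the caller:

```python
import math


def reference_attention(q, k, v, scale=None):
    """Naive softmax(q @ k^T * scale) @ v for one (batch, head) slice.

    q: L_q x D, k: L_k x D, v: L_k x D_v, given as lists of lists of floats.
    scale defaults to 1/sqrt(D), matching sdp_attention in the script below.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d) if scale is None else scale
    out = []
    for q_row in q:
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        m = max(scores)  # subtract the row max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v_row[j] for w, v_row in zip(weights, v)) / total
            for j in range(len(v[0]))
        ])
    return out


def max_abs_diff(a, b):
    """Largest elementwise |a - b| between two equally shaped 2-D lists."""
    return max(abs(x - y) for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))
```

A max diff of about 4.6 against such a reference, as in the small case above, is far outside what fp16 or bf16 rounding alone would explain.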

Full script for reproducing:

import sys, os
import torch
import math
import torch.utils.benchmark as benchmark
from torch.nn.functional import scaled_dot_product_attention
import warnings

if not torch.cuda.is_available():
    warnings.warn("CUDA not available, Flash attention will not be used.")
    HAS_FLASH = False
    HAS_FLASH_TRITON_CUDA = False
    HAS_FLASH_TRITON_CUDA_DAO = False
    HAS_FLASH_TRITON_ROCM = False
else:
    if torch.cuda.get_device_capability('cuda') < (7, 5):
        warnings.warn("GPU not supported, Flash attention will not be used.")
        HAS_FLASH = False
        HAS_FLASH_TRITON_CUDA = False
        HAS_FLASH_TRITON_CUDA_DAO = False
        HAS_FLASH_TRITON_ROCM = False
    else:
        try:
            from flash_attn import flash_attn_func
            HAS_FLASH = True
        except ImportError:
            warnings.warn("Cannot import flash_attn")
            HAS_FLASH = False
        try:
            from torching.flash_attention import flash_attn_triton
            HAS_FLASH_TRITON_CUDA = True
        except ImportError:
            warnings.warn("Cannot import flash_attn_triton")
            HAS_FLASH_TRITON_CUDA = False
        try:
            from torching.flash_attention import flash_attn_triton_dao
            HAS_FLASH_TRITON_CUDA_DAO = True
        except ImportError:
            warnings.warn("Cannot import flash_attn_triton_dao")
            HAS_FLASH_TRITON_CUDA_DAO = False
        try:
            from torching.flash_attention import flash_attn_triton_rocm
            HAS_FLASH_TRITON_ROCM = True
        except ImportError:
            warnings.warn("Cannot import flash_attn_triton_rocm")
            HAS_FLASH_TRITON_ROCM = False

def sdp_attention(q, k, v, rel_pos=None):
    """
    q: (B, H, L, D)
    """
    # with torch.cuda.amp.autocast(enabled=True, dtype=q.dtype):
    scale = 1 / math.sqrt(q.size(-1))
    attn = (q @ k.transpose(-2, -1)) * scale # (B, H, L, L)
    if rel_pos is not None:
        attn = attn + rel_pos

    attn = attn.softmax(dim=-1) # (B, H, L, L)
    # attn = self.attn_drop(attn) 
    result = attn @ v
    # (B, H, L, L) x (B, H, L, D) --> (B, H, L, D)
    return result

def flash_attention(q, k, v):
    """
    shape: (B, H, L, D)
    """
    # flash attention expects (B, L, H, d)
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    result = flash_attn_func(
                        q, k, v,
                        dropout_p=0,
                        softmax_scale=None,
                        causal=False,
                        return_attn_probs=False,
                        )
    return result.transpose(1, 2)

def flash_attention_triton(q, k, v):
    """
    shape: (B, H, seqlen_q, D_HEAD)
    """
    output = flash_attn_triton(q, k, v)
    return output

def flash_attention_triton_dao(q, k, v):
    """
    shape: (B, H, seqlen_q, D_HEAD)
    """
    output = flash_attn_triton_dao(q, k, v)
    return output

def flash_attention_triton_rocm(q, k, v):
    """
    shape: (B, H, seqlen_q, D_HEAD)
    """
    output = flash_attn_triton_rocm(q, k, v)
    return output

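def forward_backward(q, k, v, method=''):
    """
    forward + backward
    q: (B, H, L, D); q, k, v must be created with requires_grad=True
    note: `method` is currently unused; this always runs sdp_attention
    """
    # forward
    output = sdp_attention(q, k, v)
    # backward: upstream gradient with the same shape, dtype and device as the output
    grad = torch.rand_like(output)
    output.backward(grad)
    return output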

def ex():
    # (B, H, L, D)
    B = 5; H = 4; L = 10; D = 16
    # B = 5; H = 64; L = 128; D = 128
    # B = 64; H = 12; L = 256; D = 64

    q = torch.randn((B, H, L, D)).to('cuda')
    k = torch.randn((B, H, L, D)).to('cuda')
    v = torch.randn((B, H, L, D)).to('cuda')

    # bias = torch.randn((1, H, L, L))
    r1_fp32 = scaled_dot_product_attention(q, k, v)
    r1_fp16 = scaled_dot_product_attention(q.half(), k.half(), v.half())
    r1_bf16 = scaled_dot_product_attention(q.bfloat16(), k.bfloat16(), v.bfloat16())
    print(f"max diff r1_fp32 vs r1_fp16 = {(r1_fp32 - r1_fp16).abs().max()}")
    print(f"max diff r1_fp32 vs r1_bf16 = {(r1_fp32 - r1_bf16).abs().max()}")

    r2_fp32 = sdp_attention(q, k, v)
    r2_fp16 = sdp_attention(q.half(), k.half(), v.half())
    r2_bf16 = sdp_attention(q.bfloat16(), k.bfloat16(), v.bfloat16())
    print(f"max diff r1_fp32 vs r2_fp32 = {(r1_fp32 - r2_fp32).abs().max()}")
    print(f"max diff r1_fp32 vs r2_fp16 = {(r1_fp32 - r2_fp16).abs().max()}")
    print(f"max diff r1_fp32 vs r2_bf16 = {(r1_fp32 - r2_bf16).abs().max()}")

    if HAS_FLASH:
        # r3 = flash_attention(q, k, v)
        r3_fp16 = flash_attention(q.half(), k.half(), v.half())
        r3_bf16 = flash_attention(q.bfloat16(), k.bfloat16(), v.bfloat16())
        print(f"max diff r1_fp32 vs r3_fp16 = {(r1_fp32 - r3_fp16).abs().max()}")
        print(f"max diff r1_fp32 vs r3_bf16 = {(r1_fp32 - r3_bf16).abs().max()}")

    # if HAS_FLASH_TRITON_CUDA:
    #     triton_cuda_fp16 = flash_attention_triton(q.half(), k.half(), v.half())
    #     print(f"max diff r1_fp32 vs triton_cuda_fp16 = {(r1_fp32 - triton_cuda_fp16).abs().max()}")
    #     # triton_cuda_bf16 = flash_attention_triton(q.bfloat16(), k.bfloat16(), v.bfloat16())
    #     # print(f"max diff r1_fp32 vs triton_cuda_bf16 = {(r1_fp32 - triton_cuda_bf16).abs().max()}")

    # if HAS_FLASH_TRITON_CUDA_DAO:
    #     triton_cuda_dao_fp16 = flash_attention_triton_dao(q.half(), k.half(), v.half())
    #     print(f"max diff r1_fp32 vs triton_cuda_dao_fp16 = {(r1_fp32 - triton_cuda_dao_fp16).abs().max()}")
    #     triton_cuda_dao_bf16 = flash_attention_triton_dao(q.bfloat16(), k.bfloat16(), v.bfloat16())
    #     print(f"max diff r1_fp32 vs triton_cuda_dao_bf16 = {(r1_fp32 - triton_cuda_dao_bf16).abs().max()}")

    if HAS_FLASH_TRITON_ROCM:
        triton_rocm_fp16 = flash_attention_triton_rocm(q.half(), k.half(), v.half())
        triton_rocm_bf16 = flash_attention_triton_rocm(q.bfloat16(), k.bfloat16(), v.bfloat16())
        print(f"max diff r1_fp32 vs triton_rocm_fp16 = {(r1_fp32 - triton_rocm_fp16).abs().max()}")
        print(f"max diff r1_fp32 vs triton_rocm_bf16 = {(r1_fp32 - triton_rocm_bf16).abs().max()}")

def main():
    # ViT base: embed_dim=768, num_heads=12, 
    #   patch=16, size=224: H=12, D=64, L=256 
    # ViT large: embed_dim=1024, num_heads=16:
    #   patch=14, size=224: H=16, D=64, L=256
    #   patch=16, size=384: H=16, D=64, L=576
    # ViT huge: embed_dim=1280, num_heads=16:
    #    patch=16, size=224: H=16, D=80, L=256
    #    patch=16, size=448: H=16, D=80, L=784

    batch_sizes = [64, 128]

    # Compare takes a list of measurements which we'll save in results.
    results = []

    dtypes = [torch.float16, torch.bfloat16]
    # heads = [12, 16]
    # dims = [64, 80, 128]
    # seq_lengths = [256, 576, 784]

    # We replace 80 with 128 because the triton version does not support head dim 80
    HDL = [(12, 64, 256), (16, 64, 256), (16, 64, 576), (16, 128, 256), (16, 128, 784)]

    # HAS_FLASH_TRITON = False

    for dtype in dtypes:
        for H, D, L in HDL:
            for B in batch_sizes:
                q = torch.randn((B, H, L, D)).to('cuda').to(dtype)
                k = torch.randn((B, H, L, D)).to('cuda').to(dtype)
                v = torch.randn((B, H, L, D)).to('cuda').to(dtype)
                # label and sub_label are the rows
                # description is the column
                label = 'FlashAttention'
                sub_label = f'[{dtype}, {H}, {D}, {L}, {B}]'
                results.append(benchmark.Timer(
                    stmt='sdp_attention(q, k, v)',
                    setup='from __main__ import sdp_attention',
                    globals={'q': q, 'k': k, 'v': v},
                    # num_threads=num_threads,
                    label=label,
                    sub_label=sub_label,
                    description='math',
                ).blocked_autorange(min_run_time=1))
                results.append(benchmark.Timer(
                    stmt='scaled_dot_product_attention(q, k, v)',
                    setup='from torch.nn.functional import scaled_dot_product_attention',
                    globals={'q': q, 'k': k, 'v': v},
                    # num_threads=num_threads,
                    label=label,
                    sub_label=sub_label,
                    description='torch sdpa',
                ).blocked_autorange(min_run_time=1))
                if dtype != torch.float32:
                    if HAS_FLASH:
                        results.append(benchmark.Timer(
                            stmt='flash_attention(q, k, v)',
                            setup='from __main__ import flash_attention',
                            globals={'q': q, 'k': k, 'v': v},
                            # num_threads=num_threads,
                            label=label,
                            sub_label=sub_label,
                            description='flash',
                        ).blocked_autorange(min_run_time=1))
                    if HAS_FLASH_TRITON_CUDA and dtype != torch.bfloat16: # There's a bug with bfloat16
                        results.append(benchmark.Timer(
                            stmt='flash_attention_triton(q, k, v)',
                            setup='from __main__ import flash_attention_triton',
                            globals={'q': q, 'k': k, 'v': v},
                            # num_threads=num_threads,
                            label=label,
                            sub_label=sub_label,
                            description='triton-cuda',
                        ).blocked_autorange(min_run_time=1))
                    if HAS_FLASH_TRITON_CUDA_DAO:
                        results.append(benchmark.Timer(
                            stmt='flash_attention_triton_dao(q, k, v)',
                            setup='from __main__ import flash_attention_triton_dao',
                            globals={'q': q, 'k': k, 'v': v},
                            # num_threads=num_threads,
                            label=label,
                            sub_label=sub_label,
                            description='triton-cuda-dao',
                        ).blocked_autorange(min_run_time=1))
                    if HAS_FLASH_TRITON_ROCM:
                        results.append(benchmark.Timer(
                            stmt='flash_attention_triton_rocm(q, k, v)',
                            setup='from __main__ import flash_attention_triton_rocm',
                            globals={'q': q, 'k': k, 'v': v},
                            # num_threads=num_threads,
                            label=label,
                            sub_label=sub_label,
                            description='triton-rocm',
                        ).blocked_autorange(min_run_time=1))

    compare = benchmark.Compare(results)
    compare.print()

if __name__ == '__main__':
    # main()
    ex()
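For reference, the math path in the script (`sdp_attention`) computes standard scaled dot-product attention independently for each batch index and head, with the softmax taken over the last axis (`attn.softmax(dim=-1)`). torch's `scaled_dot_product_attention` and `flash_attn_func` with `softmax_scale=None` both default to the same scale $1/\sqrt{D}$, as does the `sm_scale` default in the `attention` wrapper later in the thread. All columns should therefore agree up to fp16/bf16 rounding.

```latex
S_{ij} = \frac{q_i \cdot k_j}{\sqrt{D}}, \qquad
P_{ij} = \frac{\exp(S_{ij})}{\sum_{j'=1}^{L} \exp(S_{ij'})}, \qquad
o_i = \sum_{j=1}^{L} P_{ij}\, v_j,
\qquad q_i, k_j, v_j, o_i \in \mathbb{R}^{D},\; i, j \in \{1, \dots, L\}

\text{equivalently, per (batch, head):}\quad
O = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{D}}\right) V
```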
netw0rkf10w commented 5 months ago

I tried training and got the following error:

Traceback (most recent call last):
  File "/home/user/code/ViT/main_finetune.py", line 212, in <module>
    main(args)
  File "/home/user/code/ViT/main_finetune.py", line 205, in main
    training_loop(args, model, device, data_loader_train, data_loader_val,
  File "/home/user/code/ViT/common.py", line 423, in training_loop
    train_stats = train_fn(args, model, device, data_loader_train,
  File "/home/user/code/ViT/engine_finetune.py", line 90, in train_epoch
    loss_scaler(loss, optimizer, clip_grad=max_norm,
  File "/home/user/code/ViT/utils/optim.py", line 48, in __call__
    self._scaler.scale(loss).backward(create_graph=create_graph)
  File "/home/user/.local/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/user/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/user/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/home/user/code/torching/src/torching/flash_attention/triton_rocm/flash_attention.py", line 846, in backward
    assert do.is_contiguous()
AssertionError

What I did was create a wrapper function like this:

# MetaData and _attention come from the ROCm flash_attention.py module
# (torching/flash_attention/triton_rocm/flash_attention.py in the traceback above)
def attention(q, k, v, causal=False, sm_scale=None, sequence_parallel=False):
    if sm_scale is None:
        # default softmax scale 1/sqrt(D_HEAD)
        sm_scale = q.shape[-1] ** -0.5
    metadata = MetaData(sm_scale=sm_scale)
    metadata.max_seqlens_q = q.shape[-2]
    metadata.max_seqlens_k = k.shape[-2]
    metadata.causal = causal
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    output, _ = _attention.apply(q, k, v, None, metadata)
    return output.contiguous()

I tried adding .contiguous() everywhere as above, but this didn't help...

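Update: adding the following in backward, just before the assert do.is_contiguous() check, fixed it:

# do is the gradient of the output; autograd may pass it in as a non-contiguous view
if not do.is_contiguous():
    do = do.contiguous()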
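The `.contiguous()` calls in the wrapper only affect q, k, v and the forward output. `do`, by contrast, is the gradient autograd hands to `backward`, and the ops that follow the attention call determine its layout. That is likely why the wrapper changes did not help, though the thread does not confirm it. The toy CPU sketch below uses a hypothetical `Passthrough` function as a stand-in for the ROCm kernel. A ViT-style transpose + reshape after a contiguous forward output is enough to send a non-contiguous gradient into `backward`.

```python
import torch


class Passthrough(torch.autograd.Function):
    """Toy stand-in for an attention autograd.Function: identity forward and a
    backward that records whether the incoming gradient was contiguous."""

    last_grad_was_contiguous = None

    @staticmethod
    def forward(ctx, x):
        return x.clone()  # the forward output itself is contiguous

    @staticmethod
    def backward(ctx, do):
        Passthrough.last_grad_was_contiguous = do.is_contiguous()
        # Same fix as above: make the gradient contiguous before a kernel
        # that assumes contiguous memory would read it.
        if not do.is_contiguous():
            do = do.contiguous()
        return do


def vit_style_backward(B=2, H=3, L=4, D=5):
    x = torch.randn(B, H, L, D, requires_grad=True)
    out = Passthrough.apply(x)  # (B, H, L, D), contiguous
    # Typical ViT post-processing: (B, H, L, D) -> (B, L, H, D) -> (B, L, H * D)
    y = out.transpose(1, 2).reshape(B, L, H * D)
    w = torch.randn(B, L, H * D)
    (y * w).sum().backward()
    # The gradient reaching Passthrough.backward is w viewed as (B, L, H, D) and
    # transposed back to (B, H, L, D): correct values, non-contiguous layout.
    expected_grad = w.view(B, L, H, D).transpose(1, 2)
    return Passthrough.last_grad_was_contiguous, torch.allclose(x.grad, expected_grad)
```

Running `vit_style_backward()` should report that the incoming gradient was not contiguous, while the final input gradient is still correct.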
netw0rkf10w commented 5 months ago

Lol, launching a multi-node training run, I got this again: SystemError: <built-in function load_binary> returned NULL without setting an exception.

netw0rkf10w commented 5 months ago

What annoys me is that now it has numerical issues:

I forgot to mention that the original Triton implementation and @tridao's Triton implementation also produce numerically incorrect results.

netw0rkf10w commented 5 months ago

Some more updates. The Triton implementation seems to really shine in the backward pass. Below are the results for forward+backward:

                                       |   math  |  torch  |  flash  |  triton-rocm
  [torch.float16, 12, 64, 256, 64]     |    3.0  |    2.8  |    1.3  |       1.1   
  [torch.float16, 12, 64, 256, 128]    |    5.9  |    5.5  |    2.6  |       1.3   
  [torch.float16, 16, 64, 256, 64]     |    4.0  |    3.7  |    1.8  |       1.2   
  [torch.float16, 16, 64, 256, 128]    |    7.9  |    7.4  |    3.5  |       1.7   
  [torch.float16, 16, 64, 576, 64]     |   16.7  |   14.8  |    8.2  |       3.6   
  [torch.float16, 16, 64, 576, 128]    |   33.4  |   29.4  |   16.2  |       7.0   
  [torch.float16, 16, 128, 256, 64]    |    7.5  |    7.5  |    3.5  |       2.5   
  [torch.float16, 16, 128, 256, 128]   |   14.9  |   14.9  |    6.9  |       3.7   
  [torch.float16, 16, 128, 784, 64]    |   54.6  |   51.5  |   26.7  |      14.6   
  [torch.float16, 16, 128, 784, 128]   |  109.1  |  102.9  |   52.9  |      27.2   
  [torch.bfloat16, 12, 64, 256, 64]    |    3.4  |    3.2  |    1.6  |       1.3   
  [torch.bfloat16, 12, 64, 256, 128]   |    6.7  |    6.3  |    3.1  |       1.9   
  [torch.bfloat16, 16, 64, 256, 64]    |    4.5  |    4.2  |    2.1  |       1.6   
  [torch.bfloat16, 16, 64, 256, 128]   |    8.9  |    8.4  |    4.1  |       2.5   
  [torch.bfloat16, 16, 64, 576, 64]    |   16.6  |   14.6  |   10.2  |       5.0   
  [torch.bfloat16, 16, 64, 576, 128]   |   33.3  |   29.1  |   20.2  |       9.8   
  [torch.bfloat16, 16, 128, 256, 64]   |    7.5  |    7.5  |    4.3  |       3.0   
  [torch.bfloat16, 16, 128, 256, 128]  |   15.0  |   15.0  |    8.5  |       4.7   
  [torch.bfloat16, 16, 128, 784, 64]   |   57.5  |   54.2  |   34.6  |      17.9   
  [torch.bfloat16, 16, 128, 784, 128]  |  114.7  |  108.2  |   68.7  |      33.6   

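Times are in milliseconds (ms). Each row label is [dtype, H, D, L, B], and the torch column is torch's scaled_dot_product_attention.

Triton is often about twice as fast as flash-attn (roughly 1.2x to 2.3x across these configurations)!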
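The thread does not show the code behind these forward+backward numbers; the `forward_backward` helper in the script above always calls `sdp_attention`. Below is a minimal sketch of how such a measurement could be set up with `torch.utils.benchmark`. The helper names are hypothetical. It assumes each attention function takes and returns (B, H, L, D) tensors, and it uses `torch.autograd.grad` so repeated runs do not accumulate into `.grad`.

```python
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark


def forward_backward(attn_fn, q, k, v, do):
    """One forward + backward pass; returns the gradients w.r.t. q, k, v."""
    out = attn_fn(q, k, v)
    # autograd.grad returns fresh gradients instead of accumulating into .grad,
    # so repeated timing runs do not add an extra accumulation step.
    return torch.autograd.grad(out, (q, k, v), do)


def time_forward_backward(attn_fn, B, H, L, D, dtype=torch.float16, device="cuda",
                          description="", min_run_time=1.0):
    q, k, v = (torch.randn(B, H, L, D, device=device, dtype=dtype, requires_grad=True)
               for _ in range(3))
    do = torch.randn(B, H, L, D, device=device, dtype=dtype)  # upstream gradient
    return benchmark.Timer(
        stmt="forward_backward(attn_fn, q, k, v, do)",
        globals={"forward_backward": forward_backward, "attn_fn": attn_fn,
                 "q": q, "k": k, "v": v, "do": do},
        label="FlashAttention fwd+bwd",
        sub_label=f"[{dtype}, {H}, {D}, {L}, {B}]",
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)
```

For example, `results.append(time_forward_backward(attention, 64, 12, 256, 64, description="triton-rocm"))` for each method and shape, followed by `benchmark.Compare(results).print()`, would produce a table in the same layout as above.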

Unfortunately I was not able to use it for training because of the above SystemError.

netw0rkf10w commented 5 months ago


It's finally working after I deleted ~/.triton! Apparently it was a Triton cache issue.
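Deleting ~/.triton wipes Triton's on-disk kernel cache, which by default lives under ~/.triton/cache; Triton also reads the TRITON_CACHE_DIR environment variable to locate it. The thread does not establish the root cause. One plausible contributor in multi-node runs is many processes sharing one cache in a shared home directory. Assuming that, the sketch below (hypothetical helpers, not a documented fix) clears the cache the same way and gives each host and rank a private cache directory.

```python
import os
import shutil
import socket
import tempfile


def clear_triton_cache(path="~/.triton"):
    """Delete the Triton cache directory (what fixed the error above)."""
    shutil.rmtree(os.path.expanduser(path), ignore_errors=True)


def use_private_triton_cache(rank=None, base_dir=None):
    """Point Triton's kernel cache at a directory unique to this host and rank.

    Call this before any Triton kernel is compiled (ideally before importing triton).
    """
    if rank is None:
        rank = int(os.environ.get("RANK", "0"))  # set by torchrun
    if base_dir is None:
        base_dir = tempfile.gettempdir()  # usually node-local, unlike a shared $HOME
    cache_dir = os.path.join(base_dir, f"triton_cache_{socket.gethostname()}_rank{rank}")
    os.makedirs(cache_dir, exist_ok=True)
    os.environ["TRITON_CACHE_DIR"] = cache_dir
    return cache_dir
```

The trade-off is that every rank compiles its own kernels once instead of reusing a shared cache.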

Unfortunately, though, the model didn't learn anything (the flat curve in the figure below). I officially give up!

[Screenshot 2024-02-20 at 12:19:13: training curves; the flat curve is the run described above]
zhanglx13 commented 3 months ago

We are now migrating to upstream. @netw0rkf10w, feel free to open new tickets about issues with the AMD backend at https://github.com/openai/triton/issues