ROCm / aotriton

Ahead of Time (AOT) Triton Math Library
MIT License

[Documentation]: The overall tuning idea of aotriton #38

Open hubotao1 opened 1 month ago

hubotao1 commented 1 month ago

Description of errors

Hello, author. How should I call the best kernel after AOTriton is compiled successfully? I see that you set the autotune parameter to False by default in attn_torch_function.py. Does that mean the tuning process of block_m and block_n parameters is cancelled? If it is cancelled, what is the purpose of the generated kernels? What does the final libaotriton_v2.a library contain? Can you tell me in detail? I call the libaotriton_v2.a library from PyTorch to test the performance of the flash attention (FA) operator, with autotune set to True. There seems to be no process of finding the optimal kernel on the PyTorch side. Where is the process of finding and calling the optimal kernel implemented, and how does it work? I also found that the content of the tuning_database.sqlite3 database did not change before and after compiling AOTriton. What role does it play in the overall tuning process? Thank you very much for your answer!

xinyazhang commented 1 month ago

Does that mean the tuning process of block_m and block_n parameters is cancelled?

Yes. In fact, this is not limited to attn_torch_function.py: most supporting files under tritonsrc/ have only one goal, which is to confirm that the Triton kernel is bug-free under the JIT Triton framework. Thus the tuning parameters are very conservative, because extreme autotune configs (either from tuning parameters or from compiler options) may result in a Triton kernel that gives incorrect results, causes a GPU segfault, or, worse, triggers a GPU reset.

To avoid the problems caused by extreme configs, we developed a more sophisticated tuning system (currently being refactored due to design problems). Either the old tritonsrc/tune_flash.py or the upcoming test/tune_flash.py generates/updates the tuning database v2python/rules/tuning_database.sqlite3. The build system uses this database to actually build the library.

What does the final libaotriton_v2.a library contain? Can you tell me in detail? There seems to be no process of finding the optimal kernel on the PyTorch side. Where is the process of finding and calling the optimal kernel implemented, and how does it work?

The optimal kernel lookup in AOTriton is done in pure C++ by generated code. If you built AOTriton from source, this process can be found in the following file(s) (using the attn_fwd API as an example):

  1. build/v2src/flash/shim.attn_fwd.cc. This part of the dispatcher selects a group of kernels that share the functionality hard-coded into the kernel (GPU architecture, CAUSAL=True or False, dtype fp16/bf16/fp32, etc.) and can, in theory, be used interchangeably.
  2. Files under build/v2src/flash/autotune.attn_fwd (for example FONLY__^bf16@16,False,128,False,False,False,0___MI300X.cc). These files find the optimal kernel within that group by embedding the tuning database as a lookup table (LUT) in C++ form.
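To make the two-step lookup concrete, here is a minimal Python sketch of the idea. It is not AOTriton's generated C++: the `KERNEL_GROUPS` dictionary, the `select_kernel` and `bin_index` helpers, the seqlen bins, the rule of rounding a length up to the next tuned bin, and the kernel names are all hypothetical, chosen only to show how functionals pick a kernel group and a LUT then picks the tuned kernel inside it.

```python
import bisect

# Hypothetical tuned sequence lengths (the LUT bins); illustrative only.
SEQLEN_BINS = [256, 1024, 4096]

# Stage 1 key: functionals baked into each kernel at compile time
# (arch, dtype, causal, head_dim). Value: the stage 2 LUT,
# lut[i_q][i_k] -> name of the tuned kernel.
# Entries are placeholders, not real tuning results.
KERNEL_GROUPS = {
    ("MI300X", "bf16", False, 128): [
        ["M64_N32", "M64_N64", "M64_N64"],
        ["M64_N64", "M128_N64", "M128_N64"],
        ["M64_N64", "M128_N64", "M128_N128"],
    ],
}


def bin_index(seqlen: int) -> int:
    """Index of the smallest tuned length >= seqlen, clamped to the last bin."""
    return min(bisect.bisect_left(SEQLEN_BINS, seqlen), len(SEQLEN_BINS) - 1)


def select_kernel(arch, dtype, causal, head_dim, seqlen_q, seqlen_k):
    # Stage 1: functionals must match exactly; a kernel compiled for
    # head_dim=128 is never used for head_dim=64 inputs.
    lut = KERNEL_GROUPS.get((arch, dtype, causal, head_dim))
    if lut is None:
        raise LookupError("no kernel group compiled for these functionals")
    # Stage 2: the embedded LUT supplies the tuned kernel for this shape.
    # In this sketch, n_heads is not part of either key.
    return lut[bin_index(seqlen_q)][bin_index(seqlen_k)]
```

Because the LUT is fixed when the library is built, nothing is searched at run time in this sketch; a call costs one dictionary lookup and two small binary searches.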

What role does it play in the overall tuning process?

As I described above, the tuning is not done during the build for obvious reasons: tuning any Triton kernel requires the corresponding GPUs and a clean environment, which is impractical for a build node. The build process only uses the results stored in the tuning database.
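The build-time role of the database can be sketched as a small code generator. This Python example assumes a hypothetical single-table schema (the real tuning_database.sqlite3 layout is different) and emits a C++ array for one functional group, in the spirit of the LUTs found under build/v2src/flash/autotune.attn_fwd; the table name `tuned` and the `emit_cpp_lut` helper are illustrative only. It only reads the database, which fits the observation that the database content does not change after a build.

```python
import sqlite3

# Hypothetical schema for illustration; not the real tuning_database.sqlite3 layout.
SCHEMA = """
CREATE TABLE tuned (
    arch TEXT, dtype TEXT, causal INTEGER, head_dim INTEGER,
    seqlen_q INTEGER, seqlen_k INTEGER, kernel TEXT,
    PRIMARY KEY (arch, dtype, causal, head_dim, seqlen_q, seqlen_k)
)
"""


def emit_cpp_lut(conn: sqlite3.Connection, arch: str, dtype: str,
                 causal: int, head_dim: int) -> str:
    """Render the tuned kernels of one functional group as a C++ LUT.

    The database is only read here; generating the table never writes to it.
    """
    rows = conn.execute(
        "SELECT seqlen_q, seqlen_k, kernel FROM tuned "
        "WHERE arch = ? AND dtype = ? AND causal = ? AND head_dim = ?",
        (arch, dtype, causal, head_dim),
    ).fetchall()
    if not rows:
        raise LookupError("no tuning entries for this functional group")
    qs = sorted({q for q, _, _ in rows})
    ks = sorted({k for _, k, _ in rows})
    kernels = sorted({name for _, _, name in rows})
    kernel_id = {name: i for i, name in enumerate(kernels)}
    cell = {(q, k): kernel_id[name] for q, k, name in rows}
    lines = [
        "// kernels: " + ", ".join(kernels),
        f"static const int8_t lut[{len(qs)}][{len(ks)}] = {{",
    ]
    for q in qs:
        # -1 marks a (seqlen_q, seqlen_k) pair that was never tuned.
        row = ", ".join(str(cell.get((q, k), -1)) for k in ks)
        lines.append(f"    {{{row}}},  // seqlen_q = {q}")
    lines.append("};")
    return "\n".join(lines)
```

The generated text stores kernel indices rather than names, so the LUT stays a compact integer array that compiled code can index directly.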

hubotao1 commented 1 month ago

Thank you very much for your reply!

  1. I have successfully compiled AOTriton, but when I call the generated libaotriton_v2.a library from PyTorch to test the performance of the FA operator, the performance has degraded. What is the reason? It seems that no search for the kernel with the optimal parameter configuration happens while the kernel runs (because the run is very fast).
  2. Can you explain in detail what role the tuning_database.sqlite3 database plays in the entire AOTriton build process? I found that its content does not change after AOTriton is successfully built, nor after tritonsrc/tune_flash.py is run.
  3. Can you explain in detail how to use your AOTriton library? After compiling AOTriton successfully, what else do I need to do to call the best tuned kernel?

I would be grateful if you could take the time to help answer these questions.

xinyazhang commented 1 month ago
  1. the performance has degraded. What is the reason?

Worse performance on the backward kernels is expected. An upcoming release will fix the problem and match the math backend's performance. Further improvements require additional work on Triton compiler optimizations.

2. I found that its content does not change after AOTriton is successfully built, nor after tritonsrc/tune_flash.py is run.

If the optimal kernel for each configuration stays the same, the tuning database will of course not change, even if tune_flash.py is run again.
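A minimal sketch of why re-running the tuner can leave the database unchanged, assuming (hypothetically) that only the name of the winning kernel per shape is stored, not the raw timings. The `best` table and the `record_best` and `contents` helpers are invented for illustration and are not the schema or code used by tune_flash.py.

```python
import sqlite3


def record_best(conn: sqlite3.Connection, results: dict) -> None:
    """Store the fastest kernel for each (seqlen_q, seqlen_k) shape.

    `results` maps (seqlen_q, seqlen_k) to a list of (kernel_name, time_ms).
    Only the winner's name is stored, not its timing.
    """
    conn.execute(
        "CREATE TABLE IF NOT EXISTS best ("
        " seqlen_q INTEGER, seqlen_k INTEGER, kernel TEXT,"
        " PRIMARY KEY (seqlen_q, seqlen_k))"
    )
    for (seqlen_q, seqlen_k), measurements in results.items():
        winner = min(measurements, key=lambda m: m[1])[0]
        # Upsert: an unchanged winner rewrites an identical row.
        conn.execute(
            "INSERT OR REPLACE INTO best VALUES (?, ?, ?)",
            (seqlen_q, seqlen_k, winner),
        )
    conn.commit()


def contents(conn: sqlite3.Connection) -> list:
    """All stored rows in a deterministic order."""
    return conn.execute(
        "SELECT seqlen_q, seqlen_k, kernel FROM best ORDER BY seqlen_q, seqlen_k"
    ).fetchall()
```

Timings vary from run to run, but as long as the same kernel wins for every shape, the stored rows come out identical.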

3. Can you explain in detail how to use your AOTriton library?

Just use it as a conventional C++ library. You don't need to do anything to select the optimal kernel (otherwise there would be no point in including a tuning database in the build).

The upcoming release will have options to let you select kernels manually, but these are reserved for generating the tuning database. Normally you don't need to care which kernel is eventually called, since the dispatcher selects the optimal one according to the tuning database used during compilation.

hubotao1 commented 1 month ago

Thank you very much for your reply!

I have the following questions and look forward to your response:

1: When I use only the default tuning config 'BLOCK_M': 128, 'BLOCK_N': 64 in your code to generate the static library libaotriton_v2.a, call it to test forward and backward performance, and compare that with the forward and backward performance of Flash V2 on the PyTorch side, I find that the forward performance of libaotriton_v2.a is slightly higher than that of Flash V2, but the backward performance is much worse. So I added several sets of BLOCK_M and BLOCK_N configs in an attempt to find the best ones for backward performance. Using aotriton/tritonsrc/performance_forward.py, I verified that the added 'BLOCK_M': 64, 'BLOCK_N': 32 config has higher performance. I found that the added configs were written into the database by tritonsrc/tune_flash.py. After successful compilation, I also found that in build/v2src/flash/autotune.bwd_kernel_dk_dv, 'BLOCK_M': 64, 'BLOCK_N': 32 corresponds to /public/home/zhangqha/test_code/aotriton/aotriton/build/v2src/flash/gpu_kernel_image.bwd_kernel_dk_dv/bwd_kernel_dk_dv-Sig-F^bf16@16_16_False_False_False_1P64_32CO__warp4_stg1_wave0-Gpu-K100_AI.hsaco. So this 'BLOCK_M': 64, 'BLOCK_N': 32 config does exist; yet when I test the backward performance of the libaotriton_v2.a static library, it still does not improve. Can you explain in detail why?

2: I found that no matter how I adjust the values of BLOCK_M and BLOCK_N, the backward performance of the libaotriton_v2.a static library cannot surpass Flash V2. Can you explain the reason in detail? After testing, the results were stable at 3.48 TFLOPs/s for libaotriton_v2.a bwd versus 4.72 TFLOPs/s for Flash V2 bwd. Are these results normal? Is there any other way to improve the backward performance of libaotriton_v2.a?

3: How can I find out which BLOCK_M and BLOCK_N kernel is called when I test libaotriton_v2.a with causal=False, nheads=64, headdim=64, batch_size=2, seqlen=2048? Whether or not I add better-performing BLOCK_M and BLOCK_N configs to the database, and even though those configs are present as hsaco files after successful compilation, the backward operator performance basically does not change when I test the libaotriton_v2.a static library. I suspect that libaotriton_v2.a does not use the kernels corresponding to the better-performing BLOCK_M and BLOCK_N configs I added.

4: I wrote my customized n_heads, seqlen_q and seqlen_k values into the database through tritonsrc/tune_flash.py, for example --n_heads 5 8 10 20 32 64 --d_head 16 32 64 128 --seqlen_q 64 128 256 512 1024 2048 4096 --seqlen_k 64 128 256 512 1024 2048 4096 --causal 0 1 --dropout_p 0.0 --dtype float16 bfloat16 --bias_type 0 1. I found that the multiple .cc files under build/v2src/flash/autotune.attn_fwd generated after successful compilation select the corresponding kernel only by the index of seqlen_q and seqlen_k. What is the selection method for the n_heads and d_head parameters? When the kernel is selected by seqlen_q and seqlen_k, how does it determine which n_heads and d_head values should be used?

5: How do you determine that the kernel finally selected by the generated static library libaotriton_v2.a is the best-performing one? Where is the specific implementation code?

I would be grateful if you could take the time to help answer these questions.

xinyazhang commented 1 month ago

So this 'BLOCK_M': 64, 'BLOCK_N': 32 config does exist; yet when I test the backward performance of the libaotriton_v2.a static library, it still does not improve. Can you explain in detail why?

The backward performance is a known problem. The best performance is only on par with the math backend, and we have already enumerated (64, 32) configurations in the database (see: https://github.com/ROCm/aotriton/pull/39).

Can you explain the reason in detail?

The reason is complicated, but the short answer is that the Triton compiler lacks some optimizations employed by CUTLASS/CK (Composable Kernel).

How can I find out which BLOCK_M and BLOCK_N kernel is called when I test libaotriton_v2.a with causal=False, nheads=64, headdim=64, batch_size=2, seqlen=2048?

Compile AOTriton with -DCMAKE_BUILD_TYPE=Debug, and AOTriton will print the configuration. However, be aware that BLOCK_M/N are not the only factors determining performance; waves_per_eu and num_warps are also important.

What is the selection method for the n_heads and d_head parameters?

D_HEAD is determined in step 1 (build/v2src/flash/shim.attn_fwd.cc). Such arguments are referred to as "functionals" because you cannot use a D_HEAD=128 kernel on D_HEAD=64 inputs.

N_HEADS is not considered for tuning because of its low impact on performance: the number of heads translates directly into the number of GPU blocks, and each GPU block is processed independently.
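The N_HEADS argument can be illustrated with a launch-grid calculation. The sketch below assumes a grid of (query tiles, heads, batch), a common layout for Triton flash-attention kernels (some implementations fold batch and heads into a single axis); AOTriton's actual grid may differ, and `fwd_grid`/`num_programs` are hypothetical helpers. Changing n_heads changes how many independent programs are launched, not the tile each program processes, which is why it has little effect on the choice of tuning parameters.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return -(-a // b)


def fwd_grid(batch: int, n_heads: int, seqlen_q: int, block_m: int):
    """Hypothetical (x, y, z) launch grid for a flash-attention forward kernel.

    x: tiles of BLOCK_M query rows, y: heads, z: batch. Each program handles
    one tile of one (batch, head) pair and runs independently of the others.
    """
    return (cdiv(seqlen_q, block_m), n_heads, batch)


def num_programs(grid) -> int:
    """Total number of GPU blocks (programs) launched for a grid."""
    x, y, z = grid
    return x * y * z
```

For example, batch_size=2, nheads=64, seqlen=2048 with BLOCK_M=128 gives a (16, 64, 2) grid, i.e. 2048 programs; halving nheads halves the program count, while every program still processes the same 128-row tile.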

How do you determine that the kernel finally selected by the generated static library libaotriton_v2.a is the best-performing one? Where is the specific implementation code?

For now it's determined by experimental results in https://github.com/ROCm/aotriton/blob/b5f89972a933afc148dad25af94d37e79c967765/test/mptune/core/cpp_autotune.py#L12

(largely copied from Triton, but with an added validation step)
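The general shape of a benchmark-with-validation loop is sketched below in plain Python. It is not the code in cpp_autotune.py: `Config`, `bench`, and `pick_best` are hypothetical, the timer is a CPU wall clock, and each candidate runs in-process, whereas a real GPU tuner would synchronize the device and isolate candidates so that a segfault or GPU reset cannot take down the whole run. What it illustrates is that a candidate is timed only after its output matches a reference, so a fast but wrong config can never be recorded as optimal, and that the config key includes num_warps and waves_per_eu as well as BLOCK_M/BLOCK_N.

```python
import statistics
import time
from dataclasses import dataclass
from typing import Callable, Iterable, Optional, Tuple


@dataclass(frozen=True)
class Config:
    """One tuning candidate; BLOCK_M/N alone do not determine performance."""
    block_m: int
    block_n: int
    num_warps: int
    waves_per_eu: int


def bench(fn: Callable[[], object], reps: int = 20) -> float:
    """Median wall-clock time of fn() in seconds, after one warm-up call."""
    fn()
    times = []
    for _ in range(reps):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def pick_best(
    candidates: Iterable[Config],
    make_kernel: Callable[[Config], Callable[[], object]],
    reference: object,
    close: Callable[[object, object], bool],
) -> Tuple[Config, float]:
    """Return (config, seconds) of the fastest candidate that passes validation."""
    best: Optional[Tuple[Config, float]] = None
    for cfg in candidates:
        kernel = make_kernel(cfg)
        try:
            out = kernel()
        except Exception:
            continue  # a candidate that raises is rejected, never timed
        if not close(out, reference):
            continue  # validation step: wrong results cannot win
        elapsed = bench(kernel)
        if best is None or elapsed < best[1]:
            best = (cfg, elapsed)
    if best is None:
        raise RuntimeError("no candidate passed validation")
    return best
```

Using the median of repeated runs after a warm-up call, as Triton's own benchmarking helpers also favor, keeps one-off timing noise from deciding the winner.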

hubotao1 commented 1 month ago

Thank you very much for your reply!

I have the following questions and look forward to your reply:

1: I use commit https://github.com/ROCm/aotriton/commit/04b5df8c8123f90cba3ede7e971e6fbc6040d506 (Date: Mon Jun 3 15:21:28 2024 -0500). When I test performance through tritonsrc/performance_forward.py with hq=hk=8, sq=sk=1024, dim=80, causal=false, I find that the best configuration for the forward kernel is BLOCK_M: 64, BLOCK_N: 64. After successful compilation, I call libaotriton_v2.a and test the performance with the same parameters. In build/v2src/flash/autotune.attn_fwd/FONLY__^bf16@16,False,128,False,False,True,0___K100_AI.cc, the best config for sq=sk=1024 is BLOCK_M = 128, BLOCK_N = 64; but in my tests with sq=sk=1024, BLOCK_M = 128, BLOCK_N = 64 actually performs worse than BLOCK_M: 64, BLOCK_N: 64. The same is true when printing the best kernel for the backward kernels: the best kernels selected by the two are not the same. What is the reason for this? Is it because the commit I use is too old and the feature is incomplete? Is there any way to solve this problem?

2: Below are the steps I took to use your AOTriton library. I feel that some steps are missing, and I look forward to your corrections:

  (1) Add the parameters I want to TRITON_CONFIG_LIST_FWD+BWD in tritonsrc/attn_torch_function.py, and change autotune=False to autotune=True.
  (2) Test the corresponding accuracy and performance through tritonsrc/test_forward.py, tritonsrc/test_backward.py, and tritonsrc/performance_forward.py, to determine which kernel performs best for different input parameters, so that I can check whether the kernel called after successful compilation is the same.
  (3) Run tritonsrc/tune_flash.py, adding the other parameter values I need (such as seq=77, nhead=5,10) to its "add_argument" options, and make sure they are correctly written to the database.
  (4) Compile AOTriton. After success, I also find the tuning entries from TRITON_CONFIG_LIST that I added to the database in the build/v2src/flash/autotune.attn_fwd directory.
  (5) Use the library in build/install_dir to verify whether the performance of the best kernels selected for the forward and backward operators in AOTriton is consistent with the performance of the forward and backward operators on the PyTorch side.

Have I missed any important steps, or are some of the steps wrong? I look forward to your corrections.

I would be grateful if you could take the time to help answer these questions.