triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

How to set proper candidates for tuning a triton kernel? #4257

Open sleepwalker2017 opened 1 month ago

sleepwalker2017 commented 1 month ago

I see the matmul kernel example here: https://github.com/triton-lang/triton/blob/main/python/tutorials/03-matrix-multiplication.py

You set 8 candidate configs for the fp16 GEMM.

I wonder how you found these 8 sets of parameters.

Are there other parameter values that are more efficient for GEMMs of specific shapes?

How can I search a large space to find a better kernel? Or how should I set the parameter values?

Thank you!

[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
]
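For reference, in the tutorial this list is the `configs` argument of `triton.autotune`, which times every config the first time a new shape key is seen and caches the winner; a sketch with the kernel body elided:

```
# The tutorial wires the config list up like this; autotune re-runs the
# search whenever any of the arguments named in `key` changes.
@triton.autotune(
    configs=[...],  # the list above
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(...):
    ...
```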
dengl11 commented 1 month ago

(I am not from the Triton team, so I'm just sharing my two cents.)

I think you can just generate all combinations of the autotuning parameters in Python with

import itertools
itertools.product(...)

to blindly search a large space.
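For example, a blind grid over the same parameters could be built like this (the value lists and the 64 KiB shared-memory budget are my own assumptions, not values queried from the device):

```python
import itertools

# Candidate values per tunable parameter -- illustrative choices,
# not an official recommendation.
block_m = [32, 64, 128]
block_n = [32, 64, 128, 256]
block_k = [32, 64]
stages = [3, 4, 5]
warps = [2, 4, 8]

# Cartesian product: 3 * 4 * 2 * 3 * 3 = 216 candidate configs.
grid = [
    {"BLOCK_SIZE_M": m, "BLOCK_SIZE_N": n, "BLOCK_SIZE_K": k,
     "GROUP_SIZE_M": 8, "num_stages": s, "num_warps": w}
    for m, n, k, s, w in itertools.product(block_m, block_n, block_k,
                                           stages, warps)
]

# A cheap sanity filter prunes configs whose software-pipelined tiles
# would not fit in shared memory (fp16 = 2 bytes per element).
def fits_smem(c, bytes_per_elt=2, budget=64 * 1024):
    tile_elts = (c["BLOCK_SIZE_M"] + c["BLOCK_SIZE_N"]) * c["BLOCK_SIZE_K"]
    return tile_elts * bytes_per_elt * c["num_stages"] <= budget

grid = [c for c in grid if fits_smem(c)]
print(len(grid))
```

Each dict in `grid` can then be wrapped in a `triton.Config` and handed to the autotuner; the filter keeps the blind search from wasting time compiling configs that cannot run.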

To hand-tune the parameters (instead of relying on the autotuning process to search blindly), I think you can capture an Nsight Compute profile of a Triton kernel run and see how a given parameter change affects the lower-level GPU performance metrics in the Nsight Compute app. This requires domain knowledge of GPU performance optimization (which is exactly what Triton aims to free you from via compiler optimization) and can be time-consuming.

sleepwalker2017 commented 1 month ago


Thank you! I'm now trying the blind search and it really does produce fast kernels, except that the search space is so large that it takes a long time. But that's OK, given that it saves me a lot of time compared with tuning by hand.

BTW, are you familiar with tuning kernels using Nsight Compute? I've used it, but I can't get actionable advice out of it because there are so many metrics. Do you have any advice for that, or some tutorials? Thank you!

dengl11 commented 1 month ago

Yeah, autotuning (blind search) is very powerful in my experience as well. I don't have concrete stats, but I think it is slow primarily because of repeated compilation: each candidate config being profiled triggers a JIT compilation.

I am also a beginner at Nsight Compute profiling; I have enjoyed this series of NVIDIA tutorials on YouTube, like:

sleepwalker2017 commented 1 month ago

Thank you for sharing!

I tried Triton's blind search, and it's really powerful: I generated a GEMM with Triton and it's faster than cuBLAS.