triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Matmul performance issue on MI300X #4959

Open remi-or opened 3 weeks ago

remi-or commented 3 weeks ago

Hello, I am experiencing performance issues when running Triton on an AMD MI300X GPU. When I run the tutorial script 03-matrix-multiplication.py, I get this output:

triton_output_with_fp16_inputs=tensor([[ -3.6055,  -1.8945,  17.0469,  ..., -18.2031,  16.1094, -24.1562],
        [ 23.1094,   2.1504,  -6.7305,  ...,   8.1016,   2.5625,  34.7500],
        [ -4.4062,  -2.4609,  -5.4219,  ..., -30.2031, -21.9219, -19.7188],
        ...,
        [-19.1719,  40.2500,  32.3750,  ...,  36.9688, -56.8125, -23.5625],
        [ 41.6250, -11.1797, -18.5156,  ..., -14.6250, -18.6406,  -1.9668],
        [ 27.9844,  -8.1797, -14.2031,  ...,   1.6094,   0.4919,  13.1953]],
       device='cuda:0', dtype=torch.float16)
torch_output_with_fp16_inputs=tensor([[ -3.6055,  -1.8945,  17.0469,  ..., -18.2031,  16.1094, -24.1562],
        [ 23.1094,   2.1504,  -6.7305,  ...,   8.1016,   2.5625,  34.7500],
        [ -4.4062,  -2.4609,  -5.4219,  ..., -30.2031, -21.9219, -19.7188],
        ...,
        [-19.1719,  40.2500,  32.3750,  ...,  36.9688, -56.8125, -23.5625],
        [ 41.6250, -11.1797, -18.5156,  ..., -14.6250, -18.6406,  -1.9668],
        [ 27.9844,  -8.1797, -14.2031,  ...,   1.6094,   0.4919,  13.1953]],
       device='cuda:0', dtype=torch.float16)
❌ Triton and Torch differ
matmul-performance-fp16 (throughput in TFLOPS):
         M       N       K     rocBLAS      Triton
0    256.0   256.0   256.0    3.310748    2.326614
1    384.0   384.0   384.0   10.355359    6.483809
2    512.0   512.0   512.0   22.042655   12.619786
3    640.0   640.0   640.0   29.411421   22.258035
4    768.0   768.0   768.0   45.140492   33.804839
5    896.0   896.0   896.0   79.277365   48.268620
6   1024.0  1024.0  1024.0  105.320433   62.845209
7   1152.0  1152.0  1152.0   99.513362   84.061354
8   1280.0  1280.0  1280.0  113.682181  104.910050
9   1408.0  1408.0  1408.0  140.907615  129.272171
10  1536.0  1536.0  1536.0  161.107814  149.771803
11  1664.0  1664.0  1664.0  273.195548  171.407849
12  1792.0  1792.0  1792.0  303.064305  184.877363
13  1920.0  1920.0  1920.0  345.422917  206.166081
14  2048.0  2048.0  2048.0  364.985534  223.946983
15  2176.0  2176.0  2176.0  419.225621  241.388536
16  2304.0  2304.0  2304.0  295.986112  257.754715
17  2432.0  2432.0  2432.0  339.710577  276.742157
18  2560.0  2560.0  2560.0  357.799432  290.082579
19  2688.0  2688.0  2688.0  404.353894  301.317556
20  2816.0  2816.0  2816.0  430.948811  279.204214
21  2944.0  2944.0  2944.0  469.727161  285.181540
22  3072.0  3072.0  3072.0  469.931750  316.577164
23  3200.0  3200.0  3200.0  382.770206  276.719371
24  3328.0  3328.0  3328.0  409.942200  284.030899
25  3456.0  3456.0  3456.0  429.789195  288.956066
26  3584.0  3584.0  3584.0  507.710863  269.829437
27  3712.0  3712.0  3712.0  519.121058  300.896937
28  3840.0  3840.0  3840.0  508.807076  286.446176
29  3968.0  3968.0  3968.0  582.421939  334.423096
30  4096.0  4096.0  4096.0  610.690529  320.042279
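To put a number on the gap, the Triton/rocBLAS ratio can be computed straight from the table above. The sketch below assumes, as in the tutorial's benchmark, that both columns are throughput in TFLOPS, i.e. 2·M·N·K floating-point operations divided by the measured kernel time; only a few rows are copied in, and the helper names `tflops` and `runtime_ms` are illustrative, not part of Triton.

```python
# A few rows copied from the matmul-performance-fp16 table above:
# (size, rocBLAS TFLOPS, Triton TFLOPS). Square problems, so M = N = K = size.
ROWS = [
    (1024, 105.320433, 62.845209),
    (2048, 364.985534, 223.946983),
    (3072, 469.931750, 316.577164),
    (4096, 610.690529, 320.042279),
]


def tflops(m: int, n: int, k: int, ms: float) -> float:
    """Throughput of an M x N x K GEMM that took `ms` milliseconds (2*M*N*K FLOPs)."""
    return 2 * m * n * k * 1e-12 / (ms * 1e-3)


def runtime_ms(m: int, n: int, k: int, tflops_value: float) -> float:
    """Inverse of tflops(): kernel time in milliseconds implied by a TFLOPS figure."""
    return 2 * m * n * k * 1e-12 / tflops_value * 1e3


for size, rocblas_tflops, triton_tflops in ROWS:
    ratio = triton_tflops / rocblas_tflops
    t_rocblas = runtime_ms(size, size, size, rocblas_tflops)
    t_triton = runtime_ms(size, size, size, triton_tflops)
    print(f"{size:5d}: Triton at {ratio:.0%} of rocBLAS "
          f"({t_triton:.3f} ms vs {t_rocblas:.3f} ms)")
```

Converting back to milliseconds makes it easier to see that, at 4096, the Triton kernel takes roughly twice as long as rocBLAS.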

The mismatch between Torch and Triton does not worry me much (I believe it is related to denormals), but the performance gap is a big problem. I did not see these issues on an older AMD GPU, the MI210. I have also tried building Triton from source and using another version (3.0.0), but performance was still well below rocBLAS. With Triton 3.1, I also tried adding matrix_instr_nonkdim = 16 and kpack = 2 to the config kwargs, but it did not help. Any ideas on how to fix this? Thanks!
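Rather than fixing matrix_instr_nonkdim and kpack to one value, they can be swept together with the tile sizes during autotuning. The sketch below is an assumption-laden illustration: it assumes, as the tutorial's HIP configs already do for `waves_per_eu`, that AMD backend options can be passed through the `triton.Config` kwargs; the helper names `hip_config_space` and `to_triton_configs`, the value ranges, and the `num_warps` rule are hypothetical and have not been validated on an MI300X.

```python
import itertools

try:
    import triton  # only needed to wrap the dicts; the search space itself is plain Python
except ImportError:
    triton = None

# Hypothetical search space: tile sizes stay powers of two, and the last two
# kwargs are the AMD-specific options mentioned above. Values are illustrative.
BLOCK_MN = [128, 256]
BLOCK_K = [32, 64]
NONKDIM = [16, 32]
KPACK = [1, 2]


def hip_config_space():
    """Return (kwargs, num_warps, num_stages) tuples for a small autotuning sweep."""
    space = []
    for bm, bn, bk, nonk, kpack in itertools.product(BLOCK_MN, BLOCK_MN, BLOCK_K, NONKDIM, KPACK):
        kwargs = {
            "BLOCK_SIZE_M": bm,
            "BLOCK_SIZE_N": bn,
            "BLOCK_SIZE_K": bk,
            "GROUP_SIZE_M": 4,
            "waves_per_eu": 2,
            "matrix_instr_nonkdim": nonk,
            "kpack": kpack,
        }
        # Hypothetical heuristic: use more warps for the larger output tiles.
        num_warps = 8 if bm * bn >= 256 * 128 else 4
        space.append((kwargs, num_warps, 2))
    return space


def to_triton_configs(space):
    """Wrap the plain dicts in triton.Config objects (requires Triton to be installed)."""
    if triton is None:
        raise RuntimeError("Triton is not installed")
    return [triton.Config(kw, num_warps=w, num_stages=s) for kw, w, s in space]
```

The resulting list would be passed as `configs=` to `@triton.autotune`; keeping the search space as plain dicts makes it easy to inspect or prune before any kernel is compiled.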

python version: 3.10.12
triton version: 3.1.0
device: AMD Instinct MI300X
rocm version: 6.2.1

antiagainst commented 3 weeks ago

Note that the tutorial is only meant to help you get started with Triton programming; it is not meant to be a reference high-performance kernel. If you'd like to see what changes are needed for performance, you can refer to https://github.com/triton-lang/triton/pull/4863. You can also check out some of our downstream kernels at https://github.com/ROCm/triton/tree/main_perf/python/perf-kernels.

remi-or commented 3 weeks ago

Hi @antiagainst, thanks for the links. I have tried them out, and they still don't seem to reach rocBLAS-level performance on the MI300X. This might be a bit out of scope for this issue, but do you or anyone else know of someone who has matched rocBLAS performance on these GPUs? I am asking because it seems odd to me that I could get such performance on older GPUs but not on this one. Thanks!

zhanglx13 commented 3 weeks ago

Triton doesn't provide rocBLAS-level performance for these GEMM sizes. For some GEMM sizes, Triton can get on par with rocBLAS, but that needs more advanced compiler changes, which are not in the main branch yet. For other GEMM sizes, Triton performance is usually limited by the requirement that tile sizes be powers of 2.
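Why power-of-two tiles can hurt is easiest to see with some tile and wave arithmetic. The sketch below is a deliberately crude model, not a performance predictor: it assumes the MI300X's 304 compute units each process one output tile at a time, ignores occupancy, memory bandwidth and the K dimension, and uses a hypothetical helper `tile_stats`. It reports how much of the computed tile area is padding and how full the last wave of tiles is.

```python
import math

NUM_CUS = 304  # compute units on an MI300X; crude assumption: one output tile per CU at a time


def tile_stats(m: int, n: int, block_m: int, block_n: int, num_cus: int = NUM_CUS):
    """Back-of-the-envelope occupancy for an M x N output split into block_m x block_n tiles.

    Returns (tiles, waves, padding_efficiency, last_wave_fill):
      - padding_efficiency: fraction of the computed tile area that lies inside the matrix
      - last_wave_fill: fraction of CUs busy during the final wave of tiles
    """
    tiles_m = math.ceil(m / block_m)
    tiles_n = math.ceil(n / block_n)
    tiles = tiles_m * tiles_n
    waves = math.ceil(tiles / num_cus)
    padding_efficiency = (m * n) / (tiles_m * block_m * tiles_n * block_n)
    last_wave_fill = (tiles - (waves - 1) * num_cus) / num_cus
    return tiles, waves, padding_efficiency, last_wave_fill


for size in (1664, 3584, 4096):
    for block in (128, 256):
        tiles, waves, pad_eff, fill = tile_stats(size, size, block, block)
        print(f"M=N={size}, {block}x{block} tiles: {tiles} tiles, {waves} wave(s), "
              f"{pad_eff:.0%} useful area, last wave uses {fill:.0%} of CUs")
```

For M = N = 1664, 128x128 tiles fit exactly but give only 169 tiles for 304 CUs, while 256x256 tiles give 49 tiles with about 14% of the computed area spent on padding; a non-power-of-two tile would let the work be split differently, which is the flexibility Triton lacks here.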

> I am asking because it seems odd to me that I could get such performances on older GPUs but not this one

Chances are rocBLAS does not provide tuned configs for MI200 GPUs. Can you post the numbers for both MI300 and MI200?