triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

[Persistent] Performance on 09-persistent-matmul on A100 worse than non-persistent #5174

Open CharlieFRuan opened 1 week ago

CharlieFRuan commented 1 week ago

Describe the issue

With A100, I ran two benchmarks with the provided 09-persistent-matmul.py on https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html

With M=N=4K, I ran

python 09-persistent-matmul.py --prec fp16 --K_range 128 4096 --K_step 128

With M=N=8K, I ran

python 09-persistent-matmul.py --prec fp16 --K_range 256 8192 --K_step 256

I did not change anything from the script, except for M and N when running for 4K.

I have the plots below, which show that persistent matmul only beats the non-persistent kernel for very small K, which differs from the results in the Pipelining Persistent Kernels talk.

[Plot: persistent vs non-persistent matmul TFLOPS across K for M=N=4K and M=N=8K]

I am expecting persistent matmul to perform better across all Ks due to the benefit of pipelining across different output tiles.
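For a rough sense of scale (illustrative arithmetic only, using the tutorial's default tile sizes), the non-persistent kernel launches one thread block per output tile, so for M=N=8K the launch overhead being amortized is roughly:

M = N = 8192
BLOCK_M, BLOCK_N = 128, 256
NUM_SMS = 108                            # A100

tiles = (M // BLOCK_M) * (N // BLOCK_N)  # 64 * 32 = 2048 thread blocks per launch
waves = tiles / NUM_SMS                  # ~19 waves over the 108 SMs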

I wonder if my observation is expected, and whether I have missed anything (e.g. should I set num_programs = NUM_SM * occupancy as in the softmax tutorial). Thank you!
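For concreteness, a minimal sketch of the num_programs = NUM_SM * occupancy idea I am referring to, with an assumed (not measured) occupancy of 2 programs per SM; the softmax tutorial instead derives occupancy from register and shared-memory usage:

import torch

NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count  # 108 on A100
OCCUPANCY = 2                       # assumed concurrent programs per SM
num_programs = NUM_SMS * OCCUPANCY  # grid size for the persistent kernel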

What I've tried

My script can be found here: https://github.com/CharlieFRuan/scratch/blob/03a57252aef6681b9f135a99c0a98c94fd2584f9/triton_workspace/persitent_matmul_bench/plot.ipynb

I also checked the TTGIR and searched for async ops; they all seem to be in the right places.

For the hyperparameters, I just used the provided ones in the script:

BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None

I also ran autotuning on the persistent matmul kernel with M=N=K=8192 and M=N=K=4096 over a comprehensive list of configs (search space shown below); the config provided in the tutorial is the best one according to the autotuner.

configs = []
for block_size_m in [256, 128, 64, 32]:
    for block_size_n in [256, 128, 64, 32]:
        for block_size_k in [64, 32, 16]:
            for num_stages in [3, 4]:
                for num_warps in [4, 8]:
                    configs.append(triton.Config(
                        {"BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n,
                         "BLOCK_SIZE_K": block_size_k, "GROUP_SIZE_M": 8},  # GROUP_SIZE_M kept at the tutorial default
                        num_stages=num_stages, num_warps=num_warps))
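A minimal sketch of hooking this search space up via triton.autotune, assuming the kernel should re-tune whenever M, N, or K changes (signature abbreviated, body unchanged from the tutorial):

import triton
import triton.language as tl

@triton.autotune(configs=configs, key=["M", "N", "K"])  # `configs` is the list built above
@triton.jit
def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr,  # tensor pointers
                             M, N, K,              # problem shape (autotune key)
                             BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                             BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    ...  # kernel body as in 09-persistent-matmul.py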

Environment details

Triton: main HEAD from yesterday at https://github.com/triton-lang/triton/commit/5ebd1e5a6630877e73cad547fd9495a5441c78be

GPU: A100

CharlieFRuan commented 6 days ago

I did some more benchmarking on A100. I tuned both the persistent and non-persistent kernels with the following search space on an MNK=(4K, 4K, 2K) problem (the non-persistent kernel does not take NUM_SMS):

import torch
import triton


def get_configs():
    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
    configs = []
    for bs_m in [32, 64, 128, 256]:
        for bs_n in [64, 128, 256]:
            for bs_k in [32, 64, 128]:
                # Skip the largest tile combination.
                if bs_m == 256 and bs_n == 256 and bs_k == 128:
                    continue
                for num_stages in [3, 4, 5]:
                    for num_warps in [2, 4, 8]:
                        for num_sm_multiple in [1, 2, 3]:
                            configs.append(triton.Config({
                                "BLOCK_SIZE_M": bs_m, "BLOCK_SIZE_N": bs_n, "BLOCK_SIZE_K": bs_k,
                                "GROUP_SIZE_M": 1, "NUM_SMS": num_sms * num_sm_multiple
                            }, num_stages=num_stages, num_warps=num_warps))
    return configs
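For context, a minimal sketch of how I understand NUM_SMS feeding into the persistent launch (the grid caps the number of programs at NUM_SMS, and each program then loops over several output tiles; names here are illustrative):

def persistent_grid(M, N, NUM_SMS):
    # At most NUM_SMS programs are launched, regardless of how many output tiles exist.
    return lambda META: (min(NUM_SMS,
                             triton.cdiv(M, META["BLOCK_SIZE_M"]) *
                             triton.cdiv(N, META["BLOCK_SIZE_N"])),)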

The persistent kernel's best config uses 2x the SM count (the A100 has 108 SMs, so NUM_SMS = 216):

{
  "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3,
  "num_warps": 4, "NUM_SMS": 216
}
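For this config on the MNK=(4K, 4K, 2K) tuning problem, the tile math works out roughly as follows (illustrative arithmetic only):

M = N = 4096
tiles = (M // 128) * (N // 128)  # 32 * 32 = 1024 output tiles
tiles_per_program = tiles / 216  # ~4.7 tiles looped over per persistent program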

However, the performance of persistent matmul is still quite poor (it is only better for very small K); this was run with the same script as https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html.

[Plot: persistent vs non-persistent matmul performance after tuning]

We were able to reproduce the fp8 example on H100, where persistent matmul indeed performs better, but the fp16 case does not show the same benefit.

Therefore, my question is: why is persistent matmul not performing better, despite benefits such as reduced thread-block launch overhead? Should I expect a performance boost in cases other than fp8 on H100? Thanks! cc @pawelszczerbuk