CharlieFRuan opened 1 week ago
I did some more benchmarking on A100. I tuned both the persistent and non-persistent kernels with the following search space on an MNK = (4K, 4K, 2K) problem (the non-persistent kernel's search space omits NUM_SMS):
import torch
import triton


def get_configs():
    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
    configs = []
    for bs_m in [32, 64, 128, 256]:
        for bs_n in [64, 128, 256]:
            for bs_k in [32, 64, 128]:
                # Skip the largest tile combination.
                if bs_m == 256 and bs_n == 256 and bs_k == 128:
                    continue
                for num_stages in [3, 4, 5]:
                    for num_warps in [2, 4, 8]:
                        for num_sm_multiple in [1, 2, 3]:
                            configs.append(triton.Config({
                                "BLOCK_SIZE_M": bs_m, "BLOCK_SIZE_N": bs_n, "BLOCK_SIZE_K": bs_k,
                                "GROUP_SIZE_M": 1, "NUM_SMS": num_sms * num_sm_multiple,
                            }, num_stages=num_stages, num_warps=num_warps))
    return configs
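For completeness, this is roughly how I hook the config list up to the kernel; the kernel name and signature here are illustrative placeholders, not the exact tutorial code:

import triton
import triton.language as tl

# Sketch: attach the search space to the kernel via triton.autotune,
# keyed on the problem shape. The kernel body is elided.
@triton.autotune(configs=get_configs(), key=["M", "N", "K"])
@triton.jit
def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, M, N, K,
                             BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                             BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
                             NUM_SMS: tl.constexpr):
    ...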
The persistent kernel's best config uses 2x the SM count (the A100 has 108 SMs):
{
    "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64,
    "GROUP_SIZE_M": 8, "num_stages": 3, "num_warps": 4, "NUM_SMS": 216
}
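For reference, the tutorial launches the persistent kernel on a grid capped by NUM_SMS, so NUM_SMS = 216 schedules roughly two programs per SM; a sketch of that launch logic (mirroring the tutorial, with M and N assumed defined):

import triton

# Sketch of the persistent launch: the grid is the smaller of NUM_SMS and the
# total number of output tiles, so NUM_SMS = 216 on a 108-SM A100 keeps
# about two programs resident per SM.
grid = lambda META: (min(
    META["NUM_SMS"],
    triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
),)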
However, the persistent matmul still performs quite poorly (it is only faster at very small K); I ran the same script as https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html.
We were able to reproduce the fp8 example on H100, where the persistent matmul indeed performs better, but not the fp16 one.
Therefore, my question is: why does the persistent matmul not perform better, despite benefits like reduced thread-block launch overhead? Should I expect a performance boost anywhere other than fp8 on H100? Thanks! cc @pawelszczerbuk
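For context, the defining difference is that each persistent program loops over several output tiles instead of exiting after one, which is where the reduced launch overhead should come from; a condensed sketch of that loop (simplified from the tutorial's kernel, with the tile computation elided):

import triton
import triton.language as tl

@triton.jit
def persistent_tile_loop(M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                         NUM_SMS: tl.constexpr):
    # Each program starts at its program id and strides by the grid size,
    # covering multiple output tiles without any new kernel launches.
    start_pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_tiles = tl.cdiv(M, BLOCK_SIZE_M) * num_pid_n
    for tile_id in range(start_pid, num_tiles, NUM_SMS):
        pid_m = tile_id // num_pid_n
        pid_n = tile_id % num_pid_n
        # ... load A/B blocks for (pid_m, pid_n), run the K loop, store C ...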
Describe the issue
On an A100, I ran two benchmarks with the provided
09-persistent-matmul.py
from https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html: one with
M=N=4K
and one with
M=N=8K
. I did not change anything in the script, except for M and N in the 4K run.
The plots below show that the persistent matmul is only better than the non-persistent one at very small K, unlike the results in the Pipelining Persistent Kernels talk.
I expected the persistent matmul to perform better across all K, due to the benefit of pipelining across different output tiles.
I wonder if my observation is expected, and whether I have missed anything (e.g., should I do something like
num_programs = NUM_SM * occupancy
as in the softmax tutorial? a rough sketch of what I mean follows). Thank you!
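Concretely, this is the kind of grid sizing I mean; the occupancy factor here is a placeholder assumption, not a measured value:

import torch

# Hypothetical occupancy-scaled grid sizing, in the spirit of the softmax
# tutorial: launch enough persistent programs to fill each SM's slots.
NUM_SM = torch.cuda.get_device_properties("cuda").multi_processor_count
occupancy = 2  # placeholder: assumed number of programs resident per SM
num_programs = NUM_SM * occupancy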
What I've tried
My script can be found here: https://github.com/CharlieFRuan/scratch/blob/03a57252aef6681b9f135a99c0a98c94fd2584f9/triton_workspace/persitent_matmul_bench/plot.ipynb
I also checked the TTGIR and searched for
async
-- the async ops all seem to be in the right places (a sketch of that check is below). For the hyperparameters, I just used the ones provided in the script.
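Here is roughly how that check can be done; I'm assuming a recent Triton where launching returns a compiled-kernel handle exposing an asm dict, and the kernel name, grid, and arguments below are placeholders for the actual benchmark launch:

# Sketch: grab the TTGIR from the compiled kernel handle and grep for async ops.
# `matmul_kernel_persistent`, `grid`, and the launch arguments are placeholders.
compiled = matmul_kernel_persistent[grid](a, b, c, M, N, K)
for line in compiled.asm["ttgir"].splitlines():
    if "async" in line:
        print(line)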
I also ran autotuning on the persistent matmul kernel with M=N=K=8192 and M=N=K=4096 over a comprehensive list of configs (shown below); the one provided in the script is the best according to the autotuner.
Environment details
Triton: main HEAD from yesterday at https://github.com/triton-lang/triton/commit/5ebd1e5a6630877e73cad547fd9495a5441c78be
GPU: A100