Edenzzzz opened this issue 2 months ago
Another thing to note: this hardcoded num_warps of 32 causes errors when trying to run on AMD Instinct accelerators, since they have a warp size of 64 rather than 32, and this ends up exceeding the max threads per block:
https://github.com/triton-lang/triton/issues/4128#issuecomment-2200944647 https://github.com/linkedin/Liger-Kernel/issues/231
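For concreteness, the rough launch arithmetic (not Liger code, just the block math both linked issues boil down to):

```python
# threads per block = num_warps * warp_size, and both vendors cap a block/work-group at 1024 threads
num_warps = 32
print(num_warps * 32)  # 1024 on NVIDIA (warp size 32): exactly at the 1024-thread cap
print(num_warps * 64)  # 2048 on AMD Instinct (wavefront size 64): exceeds the cap, so the launch fails
```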
Yeah, I think this needs some more adaptive tuning. Triton seems to give reasonable performance with brute-force searching/hardcoding, so folks don't seem to care much?
32 is the best tuned value I found for Llama. Did you observe better performance numbers with a different value?
The hardcoded num_warps of 32 causes this error on AMD Instinct accelerators, which I managed to avoid on my end by reducing it to 16 instead. I was subsequently able to full finetune Qwen2.5 72B at 32k ctx length and Mistral Large 123B at 16k ctx on 8xMI300X with all Liger Kernel optimizations enabled.
Perhaps it would be possible to automatically detect when an AMD Instinct accelerator is in use and adjust num_warps to a suitable value accordingly? However, I'm not sure if this reduced value of 16 is optimal performance-wise on these GPUs.
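A minimal sketch of that detection, assuming PyTorch's ROCm build (where torch.version.hip is a non-None string), and not Liger's actual code:

```python
import torch

def pick_num_warps(default: int = 32) -> int:
    """Sketch only: halve the warp count on ROCm builds."""
    if torch.version.hip is not None:
        # AMD Instinct wavefronts are 64 threads wide, so 16 * 64 = 1024 stays
        # within the 1024-thread work-group limit; 32 * 64 = 2048 does not.
        return min(default, 16)
    return default
```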
It seems the vocab size (32k or 64k at max) will be much larger than any number of threads you have, so better set num_warps to the max available on your hardware?
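Roughly, "max available" could be derived from the device's warp/wavefront size and the 1024-thread block limit; a sketch (the warp_size property isn't exposed on older PyTorch builds, hence the fallback):

```python
import torch

props = torch.cuda.get_device_properties(0)  # also works on ROCm builds
# warp_size may be missing on older PyTorch versions, so fall back based on the backend
warp_size = getattr(props, "warp_size", 64 if torch.version.hip else 32)
max_num_warps = 1024 // warp_size  # 32 on NVIDIA GPUs, 16 on AMD Instinct
```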
Makes sense. Feel free to send a PR for AMD!
@Edenzzzz I found that ROCm/vllm (the ROCm fork of vLLM) restricts num_warps to either 4 or 8 in the actual kernels, so maybe these numbers can be a reference for tuning num_warps on AMD.
https://github.com/search?q=repo%3AROCm%2Fvllm+num_warp&type=code&p=3
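Instead of hardcoding anything, Triton's autotuner could also pick among those values; a rough sketch with a placeholder kernel body (not Liger's actual cross-entropy kernel), sweeping the warp counts ROCm/vllm uses plus 16:

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[triton.Config({}, num_warps=w) for w in (4, 8, 16)],
    key=["n_cols"],  # re-tune when the column count (e.g. vocab size) changes
)
@triton.jit
def _scale_kernel(x_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    # Placeholder body, just to show where the autotuner plugs in.
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x * 2, mask=mask)
```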
I have opened a PR https://github.com/linkedin/Liger-Kernel/pull/326 to address Liger Kernel support on AMD GPUs.
After fixing num_warps, all test cases in the convergence tests passed. However, there are unit tests failing due to numerical precision. I wonder if there is a guideline on how to debug this? The failure logs are in the PR description. @ByronHsu
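For now I've been comparing against a pure-PyTorch fp32 reference with dtype-dependent tolerances, roughly like this (a hypothetical helper, not part of the existing test suite):

```python
import torch

def check_against_reference(kernel_out: torch.Tensor, reference_fp32: torch.Tensor):
    # Loosen tolerances for half-precision dtypes before declaring a real numerical bug.
    tolerances = {
        torch.float32: dict(atol=1e-5, rtol=1e-5),
        torch.bfloat16: dict(atol=1e-2, rtol=1e-2),
        torch.float16: dict(atol=1e-3, rtol=1e-3),
    }
    tol = tolerances[kernel_out.dtype]
    torch.testing.assert_close(kernel_out, reference_fp32.to(kernel_out.dtype), **tol)
```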
I wonder why you use 32 here (1024 threads per block with a warp size of 32) instead of deciding it based on the hidden size? Thanks. https://github.com/linkedin/Liger-Kernel/blob/dd86cbd2092177681acf75643ded1b23a785a816/src/liger_kernel/ops/fused_linear_cross_entropy.py#L95
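For example, something along the lines of the heuristic in Triton's fused-softmax tutorial, scaling num_warps with the block size (a sketch; MAX_FUSED_SIZE here is illustrative, not a Liger constant):

```python
import triton

MAX_FUSED_SIZE = 65536  # illustrative cap on columns handled per block

def block_and_warps(n_cols: int, warp_size: int = 32):
    block_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
    num_warps = 4
    if block_size >= 2048:
        num_warps = 8
    if block_size >= 4096:
        num_warps = 16
    # never exceed the 1024-thread block limit (relevant for warp_size=64 on AMD)
    return block_size, min(num_warps, 1024 // warp_size)
```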