jeromeku opened this issue 3 months ago
Great catch!
I've tested both formulations (computing with `grid_m` and with `grid_n`), and both appear to produce the same column-major movement. See the PyTorch code snippet in the blog post: the output with this formulation does produce `(pid_m, pid_n) = (0, 0), (1, 0), (2, 0)`, etc.
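To make the comparison concrete, here is a minimal pure-Python sketch of the pid → tile mapping. It assumes the companion formula `pid_n = pid // grid_m` (the thread only quotes the `pid_m` line, so treat that as an assumption), and `column_major` / `first_tiles` are hypothetical helper names. On a square grid the two variants are identical, which may be why they looked the same:

```python
def column_major(pid, grid_m, grid_n, modulus="grid_m"):
    """Map a linear program id to a (pid_m, pid_n) output tile.

    modulus="grid_m" is the standard column-major mapping; modulus="grid_n"
    reproduces the variant questioned in this issue. The formula for pid_n
    (pid // grid_m) is an assumption, since the thread only quotes pid_m.
    """
    pid_m = pid % (grid_m if modulus == "grid_m" else grid_n)
    pid_n = pid // grid_m
    return pid_m, pid_n


def first_tiles(grid_m, grid_n, modulus, count):
    """Return the tiles visited by the first `count` program ids."""
    return [column_major(pid, grid_m, grid_n, modulus) for pid in range(count)]


if __name__ == "__main__":
    # Square grid: both variants walk down the first column of tiles identically.
    print(first_tiles(4, 4, "grid_m", 5))  # [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1)]
    print(first_tiles(4, 4, "grid_n", 5))  # [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1)]
    # Non-square grid: the grid_n variant starts revisiting tiles after two steps.
    print(first_tiles(4, 2, "grid_n", 4))  # [(0, 0), (1, 0), (0, 0), (1, 0)]
```

With a non-square grid the variants diverge: for `grid_m=4, grid_n=2`, the `grid_n` variant yields `(0, 0), (1, 0), (0, 0), (1, 0)` for the first four programs instead of continuing down to rows 2 and 3.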
I think there's something wrong with the current code. When I simply swap the order in which `fused_moe_v0` and `fused_moe_v2` are computed in `test_moe_gemm.py` (I replaced v1 with v2), the `allclose` assertions fail.
Changing `pid_m = (pid % grid_n)` to `pid_m = (pid % grid_m)` makes the test pass, but then the performance is almost the same as the current vLLM implementation.
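A quick way to see the difference is to check tile coverage. This sketch uses hypothetical helpers, the same assumed `pid_n = pid // grid_m`, and treats programs with an out-of-range `pid_m` as writing nothing (an assumption about the kernel's early exit/masking). It lists the output tiles that are never visited or visited more than once:

```python
from collections import Counter


def tile_for(pid, grid_m, grid_n, use_grid_n):
    # pid_m as written in the issue (pid % grid_n) or the proposed fix (pid % grid_m).
    pid_m = pid % (grid_n if use_grid_n else grid_m)
    # Assumed companion formula for the column index.
    pid_n = pid // grid_m
    return pid_m, pid_n


def coverage(grid_m, grid_n, use_grid_n):
    """Return (missing, duplicated) tiles for a launch of grid_m * grid_n programs."""
    visits = Counter()
    for pid in range(grid_m * grid_n):
        pid_m, pid_n = tile_for(pid, grid_m, grid_n, use_grid_n)
        # Out-of-range tiles model programs that produce no output.
        if pid_m < grid_m and pid_n < grid_n:
            visits[(pid_m, pid_n)] += 1
    all_tiles = {(m, n) for m in range(grid_m) for n in range(grid_n)}
    missing = sorted(all_tiles - set(visits))
    duplicated = sorted(t for t, c in visits.items() if c > 1)
    return missing, duplicated


if __name__ == "__main__":
    for use_grid_n in (True, False):
        missing, duplicated = coverage(grid_m=4, grid_n=2, use_grid_n=use_grid_n)
        label = "pid % grid_n" if use_grid_n else "pid % grid_m"
        print(f"{label}: missing={missing}, duplicated={duplicated}")
```

For `grid_m=4, grid_n=2` the `pid % grid_n` variant never visits rows 2 and 3 and visits rows 0 and 1 twice, while `pid % grid_m` covers every tile exactly once for any grid shape.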
Hey, thanks for figuring this out! Yes, indeed. With this fix, the speedup from the column-major optimization unfortunately appears to be gone.
I'm still not clear why the buggy version passes the test cases we throw at it, so I'll need to dig a bit deeper on that end!
From my test results, the address of `intermediate_cache` is always the same between adjacent v0 and v2 runs. The `grid_m` problem causes some block calculations to be skipped, but we still read the correct values left over from the previous v0 run.
I don't know how the torch allocator works or how to avoid this. I tried calling `empty_cache` between the runs, but no luck.
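The stale-buffer effect can be reproduced without a GPU. The toy model below is only a sketch: the helper names are hypothetical, and plain Python lists stand in for `intermediate_cache`, with the kernel writing only the tiles its pid mapping visits (again assuming `pid_n = pid // grid_m`). If the reused buffer still holds the previous correct run's output, the skipped tiles look correct; poisoning the buffer with NaN before each run exposes them. In PyTorch the analogous step would be to fill the output tensor with NaN (e.g. `fill_(float("nan"))`) before each run rather than relying on `torch.cuda.empty_cache()`, since a new allocation is not guaranteed to be zeroed or to land at a different address.

```python
import math


def reference(grid_m, grid_n):
    # Stand-in for the correct result: one value per (m, n) output tile.
    return [[float(10 * m + n) for n in range(grid_n)] for m in range(grid_m)]


def run_kernel(out, grid_m, grid_n, use_grid_n):
    """Toy kernel: writes only the tiles its pid -> tile mapping actually visits."""
    for pid in range(grid_m * grid_n):
        pid_m = pid % (grid_n if use_grid_n else grid_m)
        pid_n = pid // grid_m  # assumed companion formula
        if pid_m < grid_m:  # out-of-range programs write nothing
            out[pid_m][pid_n] = float(10 * pid_m + pid_n)
    return out


def matches(a, b):
    # Exact comparison; NaN never compares equal, so poisoned tiles fail.
    return all(x == y for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))


def poisoned_buffer(grid_m, grid_n):
    return [[math.nan] * grid_n for _ in range(grid_m)]


if __name__ == "__main__":
    grid_m, grid_n = 4, 2
    expected = reference(grid_m, grid_n)

    # Reused buffer still holding the previous (correct) result: the bug is hidden.
    stale = reference(grid_m, grid_n)
    print(matches(run_kernel(stale, grid_m, grid_n, use_grid_n=True), expected))  # True

    # Poisoned buffer: skipped tiles stay NaN and the comparison fails.
    out = poisoned_buffer(grid_m, grid_n)
    print(matches(run_kernel(out, grid_m, grid_n, use_grid_n=True), expected))  # False

    # The pid % grid_m mapping passes even with a poisoned buffer.
    out = poisoned_buffer(grid_m, grid_n)
    print(matches(run_kernel(out, grid_m, grid_n, use_grid_n=False), expected))  # True
```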
If certain block calculations are skipped, of course that gives a speedup, doesn't it?
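Whether skipped tiles actually mean less work depends on the grid shape. A minimal sketch, assuming `pid_n = pid // grid_m` and that a program whose `pid_m >= grid_m` returns early without doing any work (both assumptions; `work_breakdown` is a hypothetical helper), classifies every launched program under `pid_m = pid % grid_n`:

```python
def work_breakdown(grid_m, grid_n):
    """Classify the grid_m * grid_n launched programs under pid_m = pid % grid_n.

    unique:     program computes a tile no earlier program computed
    duplicate:  program recomputes a tile an earlier program already computed
    early_exit: program's pid_m is out of range, so it does no work
    """
    seen = set()
    unique = duplicate = early_exit = 0
    for pid in range(grid_m * grid_n):
        pid_m, pid_n = pid % grid_n, pid // grid_m
        if pid_m >= grid_m:
            early_exit += 1
        elif (pid_m, pid_n) in seen:
            duplicate += 1
        else:
            seen.add((pid_m, pid_n))
            unique += 1
    return {"unique": unique, "duplicate": duplicate, "early_exit": early_exit}


if __name__ == "__main__":
    print(work_breakdown(4, 2))  # {'unique': 4, 'duplicate': 4, 'early_exit': 0}
    print(work_breakdown(2, 4))  # {'unique': 4, 'duplicate': 0, 'early_exit': 4}
```

Under this model, when `grid_m > grid_n` every program still computes a full tile (some of them redundantly), so the launch does not get cheaper; only when `grid_n > grid_m` do some programs exit early. Profiling the actual problem sizes would be needed to tell which case the benchmark hits.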
I also came across this problem while reading the code today. I'm not sure why we apply a particular ordering when launching the thread blocks; it seems to contradict how CUDA programming usually works.
The schedule determines the cache-reuse pattern of your algorithm. The ordering is not unique and should be optimized for the problem sizes you are working with. A similar technique is used in CUDA.
CUDA link: https://developer.nvidia.com/blog/optimizing-compute-shaders-for-l2-locality-using-thread-group-id-swizzling/
Triton link: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#l2-cache-optimizations
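For reference, the grouped ordering from the Triton tutorial linked above can be transcribed into plain Python to inspect the launch order it produces. This is a sketch for illustration only (`grouped_pid` and `launch_order` are hypothetical names), not the kernel code used in this issue:

```python
def grouped_pid(pid, num_pid_m, num_pid_n, group_size_m):
    """Grouped launch order from the Triton matmul tutorial, in plain Python.

    Programs are assigned to groups of `group_size_m` block rows; inside a
    group, consecutive pids walk down the rows before moving to the next
    column, so neighbouring programs share input tiles.
    """
    num_pid_in_group = group_size_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    # The last group may contain fewer than group_size_m rows.
    group_rows = min(num_pid_m - first_pid_m, group_size_m)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_rows)
    pid_n = (pid % num_pid_in_group) // group_rows
    return pid_m, pid_n


def launch_order(num_pid_m, num_pid_n, group_size_m):
    """Tiles in the order the programs are numbered."""
    return [grouped_pid(pid, num_pid_m, num_pid_n, group_size_m)
            for pid in range(num_pid_m * num_pid_n)]


if __name__ == "__main__":
    order = launch_order(4, 3, group_size_m=2)
    print(order[:6])  # [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]
    # Every tile is visited exactly once, even when the last group is partial.
    assert sorted(launch_order(5, 3, 2)) == [(m, n) for m in range(5) for n in range(3)]
```

Setting `group_size_m=1` reduces this to row-major order (`pid_m = pid // num_pid_n`), and `group_size_m=num_pid_m` reduces it to the column-major mapping `pid_m = pid % num_pid_m`, `pid_n = pid // num_pid_m`, so the group size is the knob that moves between the two orderings.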
@AdnanHoque @lessw2020
Thanks for the great blog post and kernels.
In the column-major ordering, why is `pid_m = (pid % grid_n)`? `grid_m` is the leading dimension (number of block rows), so shouldn't it be `pid_m = pid % grid_m`? Apologies if I'm misunderstanding the issue.
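The expected answer follows from how a column-major schedule linearizes the tile grid. A short derivation, assuming `pid_m` indexes block rows (`0 <= pid_m < grid_m`) and `pid_n` indexes block columns; the second line follows from the uniqueness of integer division with remainder by `grid_m`:

```latex
\begin{aligned}
\mathrm{pid} &= \mathrm{pid}_m + \mathrm{grid}_m \cdot \mathrm{pid}_n,
  \qquad 0 \le \mathrm{pid}_m < \mathrm{grid}_m,\; 0 \le \mathrm{pid}_n < \mathrm{grid}_n \\
\Rightarrow\quad \mathrm{pid}_m &= \mathrm{pid} \bmod \mathrm{grid}_m,
  \qquad \mathrm{pid}_n = \left\lfloor \mathrm{pid} / \mathrm{grid}_m \right\rfloor
\end{aligned}
```

Using `grid_n` as the modulus gives the same result only when `grid_m == grid_n`.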