# Introduction

Loop rotation is a lowering pass that transforms

```python
for i in range(n):
    statement1(i)
    statement2(i)
    statement3(i)
    statement4(i)
```
into
```python
statement1(0)
statement2(0)
for i in range(n):
    statement3(i)
    statement4(i)
    statement1(i+1)
    statement2(i+1)
```
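As a sanity check on the transformation, here is a minimal runnable sketch (illustrative only, not the lowering pass itself): once the rotated copies are predicated so they do not run past the last iteration, both forms execute exactly the same instruction sequence.

```python
def original(n):
    trace = []
    for i in range(n):
        trace += [("s1", i), ("s2", i), ("s3", i), ("s4", i)]
    return trace

def rotated(n):
    trace = []
    if n > 0:
        trace += [("s1", 0), ("s2", 0)]  # prologue: first iteration's s1/s2
    for i in range(n):
        trace += [("s3", i), ("s4", i)]
        if i + 1 < n:  # predicate: rotated copies stop at the last iteration
            trace += [("s1", i + 1), ("s2", i + 1)]
    return trace

assert original(5) == rotated(5)  # same dynamic order, different loop structure
```

In the matmul kernel below, the role of `statement1`/`statement2` is played by the `ld.matrix` prologue.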
In the matmul kernel, both the `cp.async` and the `ld.matrix` are circular/double buffered. This PR applies loop rotation to the matmul main loop to pull the first iteration's `ld.matrix` out of the main loop of `cp.async`. That is, it changes the code from
```cpp
cp.async prologue
// main loop for cp.async
for (...) {
    cp.async
    ld.matrix prologue // <-- to be rotated
    // main loop for ld.matrix
    for (...) {
        ld.matrix
        mma
    }
    mma // epilogue for ld.matrix
}
```
to
```cpp
cp.async prologue
ld.matrix prologue // <-- rotated
// main loop for cp.async
for (...) {
    cp.async
    // main loop for ld.matrix
    for (...) {
        ld.matrix
        mma
    }
    mma // epilogue for ld.matrix
    ld.matrix prologue // <-- rotated
}
```
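The same check can be written for this nested structure. In the hedged sketch below (`before_rotation`, `after_rotation`, `n_stages`, and `n_k` are illustrative names, and each string stands in for a batch of real instructions), both versions perform the same work, but in the rotated version the `ld.matrix` prologue for stage `i+1` is issued at the bottom of stage `i`, right after the `mma` epilogue:

```python
def before_rotation(n_stages, n_k):
    t = ["cp.async prologue"]
    for i in range(n_stages):          # main loop for cp.async
        t += [f"cp.async {i}", f"ld.matrix prologue {i}"]
        for j in range(n_k):           # main loop for ld.matrix
            t += [f"ld.matrix {i}.{j}", f"mma {i}.{j}"]
        t += [f"mma epilogue {i}"]
    return t

def after_rotation(n_stages, n_k):
    t = ["cp.async prologue", "ld.matrix prologue 0"]  # <-- rotated out
    for i in range(n_stages):
        t += [f"cp.async {i}"]
        for j in range(n_k):
            t += [f"ld.matrix {i}.{j}", f"mma {i}.{j}"]
        t += [f"mma epilogue {i}"]
        if i + 1 < n_stages:           # predicated rotated copy
            t += [f"ld.matrix prologue {i + 1}"]
    return t

assert sorted(before_rotation(3, 4)) == sorted(after_rotation(3, 4))
```

Unlike the flat example above, the dynamic order genuinely changes here: `ld.matrix prologue i+1` now precedes `cp.async i+1`. With circular buffering this is safe, because the shared-memory tile that prologue reads was filled by an earlier `cp.async`, not by `cp.async i+1`.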
In order to do so, I need to apply a reorder to change the matmul schedule. In the first schedule (before the reorder), the loop structure is
```python
for blockIdx.x:
    for blockIdx.y:
        cp.async
        for i1:  # cp.async circular buffer loop
            cp.async
            for threadIdx.z:
                for threadIdx.y:
                    ld.matrix
                    for i2:  # ld.matrix double buffer loop
                        ld.matrix
                        mma
                    mma
```
where inside the `cp.async` circular buffer loop, the entire `ld.matrix`->`mma` sequence is contained in the trivial `threadIdx` loops, so the `ld.matrix` prologue is not separable. In contrast, in the second schedule (after the reorder), we have
```python
for blockIdx.x:
    for blockIdx.y:
        cp.async
        for i1:  # cp.async circular buffer loop
            cp.async
            ld.matrix
            for i2:  # ld.matrix double buffer loop
                for threadIdx.z:
                    for threadIdx.y:
                        ld.matrix
                        mma
            for threadIdx.z:
                for threadIdx.y:
                    mma
```
The `blockIdx` and `threadIdx` loops are trivial loops, so this schedule change does not actually affect the generated CUDA kernel. However, it does make the kernel IR easier to deal with.
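To make "trivial" concrete, here is a hedged mini-codegen sketch (a hypothetical helper, not nvfuser's implementation): a loop whose index is bound to a parallel dimension takes exactly one value per CUDA thread, so codegen emits its body with no `for` statement, and nesting such loops inside or outside `i2` changes nothing in the output.

```python
PARALLEL = {"blockIdx.x", "blockIdx.y", "threadIdx.y", "threadIdx.z"}

def emit(node, depth=0):
    """Print pseudo-CUDA for a loop nest given as (index, [children]) or a leaf string."""
    if isinstance(node, str):
        print("    " * depth + node + ";")
        return
    index, body = node
    if index in PARALLEL:
        for child in body:          # trivial loop: emit the body only
            emit(child, depth)
    else:
        print("    " * depth + f"for ({index}) {{")
        for child in body:
            emit(child, depth + 1)
        print("    " * depth + "}")

# threadIdx loops inside i2 (second schedule) ...
emit(("i2", [("threadIdx.z", [("threadIdx.y", ["ld.matrix", "mma"])])]))
# ... and outside i2 (first schedule) generate the same kernel body.
emit(("threadIdx.z", [("threadIdx.y", [("i2", ["ld.matrix", "mma"])])]))
```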
# Benchmark

Using command

Before this PR:

```
---------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                        Time             CPU   Iterations
---------------------------------------------------------------------------------------------------------------------------------
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_TT_Legacy/2048/3456/4096/manual_time 990 us 2746 us 711
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_TN_Legacy/2048/3456/4096/manual_time 871 us 2629 us 720
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_NT_Legacy/2048/3456/4096/manual_time 1064 us 2821 us 579
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_TT_Legacy/2048/3456/4096/manual_time 1278 us 3034 us 499
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_TN_Legacy/2048/3456/4096/manual_time 1159 us 2914 us 607
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_NT_Legacy/2048/3456/4096/manual_time 1432 us 3188 us 447
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_TT_Legacy/2048/3456/4096/manual_time 1209 us 2966 us 526
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_TN_Legacy/2048/3456/4096/manual_time 1134 us 2892 us 619
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time 1320 us 3076 us 532
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_TT_Legacy/2048/3456/4096/manual_time 1216 us 2973 us 578
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_TN_Legacy/2048/3456/4096/manual_time 1114 us 2872 us 550
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time 1371 us 3130 us 512
EagerModeMatmul/no_quant_eagermode_TT_Legacy/2048/3456/4096/manual_time 845 us 912 us 832
EagerModeMatmul/no_quant_eagermode_TN_Legacy/2048/3456/4096/manual_time 916 us 985 us 765
EagerModeMatmul/no_quant_eagermode_NT_Legacy/2048/3456/4096/manual_time 792 us 863 us 884
```

After this PR:

```
---------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                        Time             CPU   Iterations
---------------------------------------------------------------------------------------------------------------------------------
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_TT_Legacy/2048/3456/4096/manual_time 887 us 2643 us 793
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_TN_Legacy/2048/3456/4096/manual_time 899 us 2655 us 780
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_NT_Legacy/2048/3456/4096/manual_time 978 us 2734 us 717
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_TT_Legacy/2048/3456/4096/manual_time 891 us 2648 us 787
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_TN_Legacy/2048/3456/4096/manual_time 899 us 2655 us 782
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_NT_Legacy/2048/3456/4096/manual_time 997 us 2753 us 704
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_TT_Legacy/2048/3456/4096/manual_time 904 us 2662 us 732
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_TN_Legacy/2048/3456/4096/manual_time 903 us 2660 us 778
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time 946 us 2703 us 742
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_TT_Legacy/2048/3456/4096/manual_time 888 us 2646 us 727
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_TN_Legacy/2048/3456/4096/manual_time 885 us 2643 us 794
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time 945 us 2702 us 744
EagerModeMatmul/no_quant_eagermode_TT_Legacy/2048/3456/4096/manual_time 845 us 911 us 832
EagerModeMatmul/no_quant_eagermode_TN_Legacy/2048/3456/4096/manual_time 916 us 985 us 765
EagerModeMatmul/no_quant_eagermode_NT_Legacy/2048/3456/4096/manual_time 792 us 863 us 884
```