csarofeen / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
http://pytorch.org

Improve matmul instruction scheduling with loop rotation #2488

Closed zasdfgbnm closed 1 year ago

zasdfgbnm commented 1 year ago

Introduction

Loop rotation is a lowering pass that transforms

for i in range(n):
  statement1(i)
  statement2(i)
  statement3(i)
  statement4(i)

into

statement1(0)
statement2(0)
for i in range(n):
  statement3(i)
  statement4(i)
  statement1(i+1)
  statement2(i+1)
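As a sanity check, the transformation can be sketched in plain Python (a toy model of the pass, not nvFuser code). Note that the rotated form needs a guard so that the pulled-ahead statements are not executed past the last iteration:

```python
def run_original(n):
    # record every statement execution as (name, index)
    log = []
    for i in range(n):
        log += [("s1", i), ("s2", i), ("s3", i), ("s4", i)]
    return log

def run_rotated(n):
    log = []
    if n > 0:                     # rotated prologue, guarded for n == 0
        log += [("s1", 0), ("s2", 0)]
    for i in range(n):
        log += [("s3", i), ("s4", i)]
        if i + 1 < n:             # predicate: no (i+1)-th prologue past the end
            log += [("s1", i + 1), ("s2", i + 1)]
    return log

# same set of statement executions, different interleaving
assert sorted(run_original(5)) == sorted(run_rotated(5))
```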

In the matmul kernel, both the cp.async and the ld.matrix are circular/double buffered. This PR applies loop rotation to the matmul main loop to pull the ld.matrix prologue of the first iteration out of the cp.async main loop.

That is, to change the code from

cp.async prologue
// main loop for cp.async
for (...) {
  cp.async
  ld.matrix prologue // <-- to be rotated
  // main loop for ld.matrix
  for (...) {
    ld.matrix
    mma
  }
  mma // epilogue for ld.matrix
}

to

cp.async prologue
ld.matrix prologue  // <-- rotated
// main loop for cp.async
for (...) {
  cp.async
  // main loop for ld.matrix
  for (...) {
    ld.matrix
    mma
  }
  mma // epilogue for ld.matrix
  ld.matrix prologue  // <-- rotated
}
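The payoff is classic software pipelining: once the prologue is hoisted, each iteration issues the next buffer's load before computing on the current one. A toy Python analogue of the double-buffered structure (my illustration, not the generated kernel):

```python
def pipelined_sum(data):
    # double buffer: prefetch element i+1 while "computing" on element i,
    # with the first load hoisted out of the loop as a prologue
    n = len(data)
    buf = [None, None]
    if n > 0:
        buf[0] = data[0]                     # prologue load
    total = 0
    for i in range(n):
        if i + 1 < n:
            buf[(i + 1) % 2] = data[i + 1]   # issue next load early
        total += buf[i % 2]                  # compute on current buffer
    return total
```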

To do so, I need to reorder the matmul schedule from

//                               vvvvvv ld.matrix double buffer loop
[BIDx, BIDy, Serial, TIDz, TIDy, Serial]
//           ^^^^^^ cp.async circular buffer loop

to

//                   vvvvvv ld.matrix double buffer loop
[BIDx, BIDy, Serial, Serial, TIDz, TIDy]
//           ^^^^^^ cp.async circular buffer loop

This is because, in the first schedule, the loop structure is

for blockIdx.x:
  for blockIdx.y:
    cp.async
    for i1:  # cp.async circular buffer loop
      cp.async
      for threadIdx.z:
        for threadIdx.y:
          ld.matrix
          for i2:  # ld.matrix double buffer loop
            ld.matrix
            mma
          mma

where, inside the cp.async circular buffer loop, the entire ld.matrix->mma sequence is contained in the trivial threadIdx loops, so the ld.matrix prologue cannot be separated out for rotation.

In contrast, for the second schedule, we have

for blockIdx.x:
  for blockIdx.y:
    cp.async
    for i1:  # cp.async circular buffer loop
      cp.async
      ld.matrix
      for i2:  # ld.matrix double buffer loop
        for threadIdx.z:
          for threadIdx.y:
            ld.matrix
            mma
      for threadIdx.z:
        for threadIdx.y:
          mma

The blockIdx and threadIdx loops are trivial loops, so this schedule change doesn't affect the generated CUDA kernel. However, it does make the kernel IR easier to deal with.
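To make the point about loop structure concrete, here is a small Python sketch (a hypothetical helper, not the nvFuser scheduler API) that renders the nest implied by an axis order; with the second order, the ld.matrix double-buffer loop sits outside the threadIdx loops, so its prologue can be hoisted on its own:

```python
def loop_nest(axes):
    # parallel axes (BIDx/BIDy/TIDz/TIDy) become trivial per-thread loops,
    # Serial axes become real for-loops; either way they nest in order
    return "\n".join("  " * d + f"for {ax}:" for d, ax in enumerate(axes))

first  = ["BIDx", "BIDy", "Serial_cp_async", "TIDz", "TIDy", "Serial_ld_matrix"]
second = ["BIDx", "BIDy", "Serial_cp_async", "Serial_ld_matrix", "TIDz", "TIDy"]

print(loop_nest(second))
```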

Benchmark

Using command

CUDA_VISIBLE_DEVICES=1 ./build/bin/nvfuser_bench --benchmark_filter=.*Matmul.*Legacy/2048/3456/4096.*

Before this PR:

---------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                       Time             CPU   Iterations
---------------------------------------------------------------------------------------------------------------------------------
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_TT_Legacy/2048/3456/4096/manual_time        990 us         2746 us          711
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_TN_Legacy/2048/3456/4096/manual_time        871 us         2629 us          720
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_NT_Legacy/2048/3456/4096/manual_time       1064 us         2821 us          579
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_TT_Legacy/2048/3456/4096/manual_time       1278 us         3034 us          499
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_TN_Legacy/2048/3456/4096/manual_time       1159 us         2914 us          607
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_NT_Legacy/2048/3456/4096/manual_time       1432 us         3188 us          447
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_TT_Legacy/2048/3456/4096/manual_time       1209 us         2966 us          526
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_TN_Legacy/2048/3456/4096/manual_time       1134 us         2892 us          619
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time       1320 us         3076 us          532
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_TT_Legacy/2048/3456/4096/manual_time       1216 us         2973 us          578
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_TN_Legacy/2048/3456/4096/manual_time       1114 us         2872 us          550
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time       1371 us         3130 us          512
EagerModeMatmul/no_quant_eagermode_TT_Legacy/2048/3456/4096/manual_time                       845 us          912 us          832
EagerModeMatmul/no_quant_eagermode_TN_Legacy/2048/3456/4096/manual_time                       916 us          985 us          765
EagerModeMatmul/no_quant_eagermode_NT_Legacy/2048/3456/4096/manual_time                       792 us          863 us          884

After this PR:

---------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                       Time             CPU   Iterations
---------------------------------------------------------------------------------------------------------------------------------
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_TT_Legacy/2048/3456/4096/manual_time        887 us         2643 us          793
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_TN_Legacy/2048/3456/4096/manual_time        899 us         2655 us          780
Nvfuser_Matmul_4warp3stage/no_quant_nvfuser_4warp_NT_Legacy/2048/3456/4096/manual_time        978 us         2734 us          717
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_TT_Legacy/2048/3456/4096/manual_time        891 us         2648 us          787
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_TN_Legacy/2048/3456/4096/manual_time        899 us         2655 us          782
Nvfuser_Matmul_4warp4stage/no_quant_nvfuser_4warp_NT_Legacy/2048/3456/4096/manual_time        997 us         2753 us          704
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_TT_Legacy/2048/3456/4096/manual_time        904 us         2662 us          732
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_TN_Legacy/2048/3456/4096/manual_time        903 us         2660 us          778
Nvfuser_Matmul_8warp3stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time        946 us         2703 us          742
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_TT_Legacy/2048/3456/4096/manual_time        888 us         2646 us          727
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_TN_Legacy/2048/3456/4096/manual_time        885 us         2643 us          794
Nvfuser_Matmul_8warp4stage/no_quant_nvfuser_8warp_NT_Legacy/2048/3456/4096/manual_time        945 us         2702 us          744
EagerModeMatmul/no_quant_eagermode_TT_Legacy/2048/3456/4096/manual_time                       845 us          911 us          832
EagerModeMatmul/no_quant_eagermode_TN_Legacy/2048/3456/4096/manual_time                       916 us          985 us          765
EagerModeMatmul/no_quant_eagermode_NT_Legacy/2048/3456/4096/manual_time                       792 us          863 us          884