NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[QST] Performance Issue of doing GEMM on A100 using CuTe #1858

Open Yanksi opened 1 week ago

Yanksi commented 1 week ago

Hi, I've just created a small project (link to the project) by modifying the sgemm_sm80 example to use the tensor cores for the computation. Unfortunately, when testing on an A100, the performance never seems to reach peak. Below are the best results I've gotten from the autotuning process: the best performance always seems to be a bit less than half of the A100's theoretical peak. Any comments on how I can improve this?


Results for float (GFLOP/s)
TN:
  gemm_float_config1672_TN  63901.4
  gemm_float_config2318_TN  62808.7
  gemm_float_config1708_TN  61919.7
  gemm_float_config2354_TN  61474.6
  gemm_float_config1702_TN  60324.8
NT:
  gemm_float_config1449_NT  62819.8
  gemm_float_config2390_NT  60782.0
  gemm_float_config1412_NT  60040.9
  gemm_float_config1448_NT  58810.4
  gemm_float_config1376_NT  58670.4

Results for half (GFLOP/s)
TN:
  gemm_half_config2750_TN  155641.2
  gemm_half_config2752_TN  154693.0
  gemm_half_config1804_TN  147818.7
  gemm_half_config2456_TN  146790.2
  gemm_half_config1808_TN  145774.8
NT:
  gemm_half_config1840_NT   92024.8
  gemm_half_config2492_NT   91585.4
  gemm_half_config1844_NT   91393.5
  gemm_half_config1808_NT   90688.7
  gemm_half_config1514_NT   90666.0
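One way to check the "a bit less than half" observation is to divide the best results by the A100 peak. The sketch below assumes the published dense (non-sparse) A100 peaks of 312 TFLOPS for FP16 tensor cores, 156 TFLOPS for TF32 tensor cores, and 19.5 TFLOPS for plain FP32 on CUDA cores. Treating the float runs as TF32 is also an assumption, supported only by the fact that they exceed the plain-FP32 peak. fraction_of_peak and print_peak_report are hypothetical helpers, not part of the project.

```cpp
#include <cassert>
#include <cstdio>

// ASSUMPTION: A100 dense (non-sparse) peak throughputs, in GFLOP/s.
constexpr double kPeakFp16TensorGflops = 312000.0;  // FP16 inputs on tensor cores
constexpr double kPeakTf32TensorGflops = 156000.0;  // float inputs executed as TF32
constexpr double kPeakFp32SimtGflops   =  19500.0;  // plain FP32 FMA on CUDA cores

// Fraction of an assumed peak reached by a measured throughput.
constexpr double fraction_of_peak(double measured_gflops, double peak_gflops) {
  return measured_gflops / peak_gflops;
}

// Prints the best autotuned result of each table as a percentage of its assumed peak.
void print_peak_report() {
  std::printf("half  TN: %.1f%% of FP16 tensor-core peak\n",
              100.0 * fraction_of_peak(155641.2, kPeakFp16TensorGflops));
  std::printf("half  NT: %.1f%% of FP16 tensor-core peak\n",
              100.0 * fraction_of_peak(92024.8, kPeakFp16TensorGflops));
  std::printf("float TN: %.1f%% of TF32 tensor-core peak\n",
              100.0 * fraction_of_peak(63901.4, kPeakTf32TensorGflops));
  std::printf("float NT: %.1f%% of TF32 tensor-core peak\n",
              100.0 * fraction_of_peak(62819.8, kPeakTf32TensorGflops));
  // Exceeding the plain-FP32 peak suggests the float runs use (TF32) tensor cores.
  std::printf("float TN: %.1fx the plain FP32 peak\n", 63901.4 / kPeakFp32SimtGflops);
}
```

Under these assumptions the half TN result sits just under 50% of peak, and the half NT result is noticeably lower, at roughly 30%.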
ccecka commented 1 week ago

Your rewrite of sgemm_sm80 with the extra tiling of mma_k and extra loop is unnecessary and could inhibit the compiler from properly pipelining smem and rmem and mma. I don't recommend that.

Beyond that, it appears that you're only using row-major/col-major smem, SIMT smem->rmem TiledCopys, and trivial TiledMMAs. I agree that we should include a non-trivial SM80 tensor-core example in the tutorial...

I did need to modify the sgemm_sm80 kernel slightly to use dynamic shared memory (so it can allocate more than the 48 KB available to static smem), and got this performance:

** New HALF_T for all TA, TB, TC, TI **
$ ./sgemm_sm80 5120 5120 4096 T N
Using device 0: NVIDIA A100-PCIE-40GB (SM80, 108)
M = 5120
N = 5120
K = 4096
C = A^T B^N
CUTE_GEMM:     [126161.8]GFlop/s  (1.7022)ms
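The GFLOP/s figure is consistent with the usual GEMM operation count of 2*M*N*K (one multiply and one add per (m, n, k) triple) divided by the runtime. A minimal sketch; gemm_flops and gemm_gflops are hypothetical helper names:

```cpp
#include <cassert>
#include <cstdint>

// GEMM work: one multiply and one add per (m, n, k) triple, i.e. 2*M*N*K FLOPs.
constexpr double gemm_flops(std::int64_t m, std::int64_t n, std::int64_t k) {
  return 2.0 * static_cast<double>(m) * static_cast<double>(n) * static_cast<double>(k);
}

// Throughput in GFLOP/s for a runtime given in milliseconds.
constexpr double gemm_gflops(std::int64_t m, std::int64_t n, std::int64_t k, double ms) {
  return gemm_flops(m, n, k) / (ms * 1.0e-3) / 1.0e9;
}
```

For the run above, gemm_gflops(5120, 5120, 4096, 1.7022) comes out at about 126159 GFLOP/s; the small gap to the printed 126161.8 is explained by the millisecond value being rounded to four decimals.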

The next step would be to use LDSM for the smem->rmem load. This smem layout is already designed for the LDSM access pattern, and the current smem->rmem load is what is slowing the kernel down now. I've left notes on this in the diff, and CUTLASS's SM80 Collective shows how it's done. That should achieve speed-of-light, peak performance on A100.

Here's my diff/patch with the half_t configuration:

diff --git a/cutlass/examples/cute/tutorial/sgemm_sm80.cu b/cutlass/examples/cute/tutorial/sgemm_sm80.cu
index d042933040..69b8457c1c 100644
--- a/cutlass/examples/cute/tutorial/sgemm_sm80.cu
+++ b/cutlass/examples/cute/tutorial/sgemm_sm80.cu
@@ -46,6 +46,16 @@
 #include "cutlass/util/helper_cuda.hpp"

+template <class ElementA,
+          class ElementB,
+          class SmemLayoutA,
+          class SmemLayoutB>
+struct SharedStorage
+{
+  cute::array<ElementA, cute::cosize_v<SmemLayoutA>> A;
+  cute::array<ElementB, cute::cosize_v<SmemLayoutB>> B;
+};
+
 template <class ProblemShape, class CtaTiler,
           class TA, class AStride, class ASmemLayout, class TiledCopyA,
           class TB, class BStride, class BSmemLayout, class TiledCopyB,
@@ -100,10 +110,11 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
   Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

   // Shared memory buffers
-  __shared__ TA smemA[cosize_v<ASmemLayout>];
-  __shared__ TB smemB[cosize_v<BSmemLayout>];
-  Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K,PIPE)
-  Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);            // (BLK_N,BLK_K,PIPE)
+  extern __shared__ char shared_memory[];
+  using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
+  SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
+  Tensor sA = make_tensor(make_smem_ptr(smem.A.data()), sA_layout);    // (BLK_M,BLK_K,PIPE)
+  Tensor sB = make_tensor(make_smem_ptr(smem.B.data()), sB_layout);    // (BLK_N,BLK_K,PIPE)

   //
   // Partition the copying of A and B tiles across the threads
@@ -301,6 +312,127 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
   axpby(alpha, tCrC, beta, tCgC);
 }

+template <class Alpha, class Beta>
+void
+gemm_nt(int m, int n, int k,
+        Alpha alpha,
+        cute::half_t const* A, int ldA,
+        cute::half_t const* B, int ldB,
+        Beta beta,
+        cute::half_t      * C, int ldC,
+        cudaStream_t stream = 0)
+{
+  assert(false && "Not implemented");
+}
+
+// Setup params for a TN HGEMM
+template <class Alpha, class Beta>
+void
+gemm_tn(int m, int n, int k,
+        Alpha alpha,
+        cute::half_t const* A, int ldA,
+        cute::half_t const* B, int ldB,
+        Beta beta,
+        cute::half_t      * C, int ldC,
+        cudaStream_t stream = 0)
+{
+  using namespace cute;
+
+  // Define shapes (dynamic)
+  auto M = int(m);
+  auto N = int(n);
+  auto K = int(k);
+  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)
+
+  // Define TN strides (mixed)
+  auto dA = make_stride(ldA, Int<1>{});                      // (dM, dK)
+  auto dB = make_stride(ldB, Int<1>{});                      // (dN, dK)
+  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)
+
+  // Define CTA tile sizes (static)
+  auto bM = Int<128>{};
+  auto bN = Int<128>{};
+  auto bK = Int< 64>{};
+  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
+  auto bP = Int<3>{};  // Pipeline
+
+  // Define the smem layouts (static)
+  // Swizzles for LDSM and 128b k-major loads
+  auto swizzle_atom = composition(Swizzle<3,3,3>{},
+                                  Layout<Shape <_8,Shape <_8, _8>>,
+                                         Stride<_8,Stride<_1,_64>>>{});
+  auto sA = tile_to_shape(swizzle_atom, make_shape(bM,bK,bP));
+  auto sB = tile_to_shape(swizzle_atom, make_shape(bN,bK,bP));
+  auto sC = make_layout(make_shape(bM, bN));
+
+  // Define the thread layouts (static)
+
+  TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::half_t>{},
+                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{},  // Thr layout 32x8 k-major
+                                    Layout<Shape< _1,_8>>{});               // Val layout  1x8 k-major
+  TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::half_t>{},
+                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{},  // Thr layout 32x8 k-major
+                                    Layout<Shape< _1,_8>>{});               // Val layout  1x8 k-major
+
+  TiledMMA mmaC = make_tiled_mma(SM80_16x8x16_F16F16F16F16_TN{},
+                                 Layout<Shape<_2,_4>>{},    // 2x4x1 MMA Atoms
+                                 Tile<_32,_64,_16>{});      // 32x64x16 Tiled MMA for LDSM
+
+  // //Copy_Atom<DefaultCopy, half_t> copy_atom_A;
+  // //Copy_Atom<UniversalCopy<half_t>, half_t> copy_atom_A;
+  // //Copy_Atom<SM75_U32x1_LDSM_N, half_t> copy_atom_A;
+  // //Copy_Atom<SM75_U32x2_LDSM_N, half_t> copy_atom_A;
+  // Copy_Atom<SM75_U32x4_LDSM_N, half_t> copy_atom_A;
+  // TiledCopy copyA = make_tiled_copy_A(copy_atom_A, mmaC);
+
+  // //Copy_Atom<DefaultCopy, half_t> copy_atom_B;
+  // //Copy_Atom<UniversalCopy<half_t>, half_t> copy_atom_B;
+  // //Copy_Atom<SM75_U32x1_LDSM_N, half_t> copy_atom_B;
+  // //Copy_Atom<SM75_U32x2_LDSM_N, half_t> copy_atom_B;
+  // Copy_Atom<SM75_U32x4_LDSM_N, half_t> copy_atom_B;
+  // TiledCopy copyB = make_tiled_copy_B(copy_atom_B, mmaC);
+
+#if 0
+  print(copyA);
+  print(copyB);
+  print(mmaC);
+#endif
+
+#if 0
+  print_latex(copyA);
+  print_latex(copyB);
+  print_latex(mmaC);
+#endif
+
+  int smem_size = int(sizeof(SharedStorage<cute::half_t, cute::half_t, decltype(sA), decltype(sB)>));
+  dim3 dimBlock(size(mmaC));
+  dim3 dimGrid(size(ceil_div(M, bM)),
+               size(ceil_div(N, bN)));
+
+  auto kernel_fptr = gemm_device<
+    decltype(prob_shape), decltype(cta_tiler),
+    cute::half_t, decltype(dA), decltype(sA), decltype(copyA),
+    cute::half_t, decltype(dB), decltype(sB), decltype(copyB),
+    cute::half_t, decltype(dC), decltype(sC), decltype(mmaC),
+    decltype(alpha), decltype(beta)>;
+
+  // Opt in to >48 KB dynamic smem and prefer the max smem carveout (L1 mostly SMEM)
+  cudaFuncSetAttribute(
+    kernel_fptr,
+    cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+
+  cudaFuncSetAttribute(
+    kernel_fptr,
+    cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+
+  kernel_fptr<<<dimGrid, dimBlock, smem_size, stream>>>
+      (prob_shape, cta_tiler,
+       A, dA, sA, copyA,
+       B, dB, sB, copyB,
+       C, dC, sC, mmaC,
+       alpha, beta);
+}
+
 // Setup params for a NT GEMM
 template <class TA, class TB, class TC,
           class Alpha, class Beta>
@@ -362,10 +494,11 @@ gemm_nt(int m, int n, int k,
   print_latex(mmaC);
 #endif

+  int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
   dim3 dimBlock(size(mmaC));
   dim3 dimGrid(size(ceil_div(M, bM)),
                size(ceil_div(N, bN)));
-  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
+  gemm_device<<<dimGrid, dimBlock, smem_size, stream>>>
       (prob_shape, cta_tiler,
        A, dA, sA, copyA,
        B, dB, sB, copyB,
@@ -438,10 +571,11 @@ gemm_tn(int m, int n, int k,
   print_latex(mmaC);
 #endif

+  int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
   dim3 dimBlock(size(mmaC));
   dim3 dimGrid(size(ceil_div(M, bM)),
                size(ceil_div(N, bN)));
-  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
+  gemm_device<<<dimGrid, dimBlock, smem_size, stream>>>
       (prob_shape, cta_tiler,
        A, dA, sA, copyA,
        B, dB, sB, copyB,
@@ -485,6 +619,11 @@ int main(int argc, char** argv)
     return 0;
   }

+  std::cout << "Using device 0: " << props.name
+            << " (SM" << props.major * 10 + props.minor
+            << ", " << props.multiProcessorCount
+            << ")" << std::endl;
+
   int m = 5120;
   if (argc >= 2)
     sscanf(argv[1], "%d", &m);
@@ -505,13 +644,13 @@ int main(int argc, char** argv)
   if (argc >= 6)
     sscanf(argv[5], "%c", &transB);

-  using TA = float;
-  using TB = float;
-  using TC = float;
-  using TI = float;
+  using TA = cute::half_t;
+  using TB = cute::half_t;
+  using TC = cute::half_t;
+  using TI = cute::half_t;

-  TI alpha = 1.0;
-  TI beta  = 0.0;
+  TI alpha = static_cast<TI>(1.0f);
+  TI beta  = static_cast<TI>(0.0f);

   std::cout << "M = " << m << std::endl;
   std::cout << "N = " << n << std::endl;
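To see why the patch switches to dynamic shared memory, it helps to work through the sizes implied by the half_t TN configuration. The sketch below only does host-side arithmetic with values copied from the diff; the constant and function names are hypothetical, and it assumes SharedStorage adds no padding between its two arrays.

```cpp
#include <cassert>
#include <cstddef>

// Values copied from the half_t TN configuration in the diff (gemm_tn).
constexpr int kBM = 128, kBN = 128, kBK = 64;  // CTA tile (bM, bN, bK)
constexpr int kPipe = 3;                        // bP: pipeline stages
constexpr int kHalfBytes = 2;                   // sizeof(cute::half_t)

// TiledMMA: Layout<Shape<_2,_4>> of MMA atoms, one warp (32 threads) per atom.
constexpr int kMmaThreads = 32 * 2 * 4;

// TiledCopy: Layout<Shape<_32,_8>> threads, each copying 1x8 halves (128 bits).
constexpr int kCopyThreads  = 32 * 8;
constexpr int kCopyTileRows = 32;     // rows covered by one pass of the TiledCopy
constexpr int kCopyTileCols = 8 * 8;  // 8 threads x 8 halves along K

// Bytes of SharedStorage<half_t, half_t, sA, sB>: both pipelined tiles, no padding assumed.
constexpr std::size_t kSmemBytes =
    static_cast<std::size_t>(kBM * kBK * kPipe + kBN * kBK * kPipe) * kHalfBytes;

// Dynamic smem above this needs cudaFuncAttributeMaxDynamicSharedMemorySize.
constexpr std::size_t kDefaultSmemLimitBytes = 48 * 1024;

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Number of TiledCopy passes needed to fill one (kBM x kBK) stage of sA.
constexpr int copy_passes_per_stage_A() {
  return (kBM / kCopyTileRows) * (kBK / kCopyTileCols);
}
```

With three 128x64 stages each for A and B, the buffers need 96 KB, twice the 48 KB available without opting in, which is why gemm_tn calls cudaFuncSetAttribute with cudaFuncAttributeMaxDynamicSharedMemorySize before the launch. The 256 copy threads also match the 256 threads of the TiledMMA, so dimBlock(size(mmaC)) covers both.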
Yanksi commented 1 week ago

I've updated my project over the past few days and added a README.md for better documentation (hopefully).

> Your rewrite of sgemm_sm80 with the extra tiling of mma_k and extra loop is unnecessary and could inhibit the compiler from properly pipelining smem and rmem and mma. I don't recommend that.

The extra tiling of mma_k was meant to mirror the smem->rmem pipeline shown in the sgemm_sm80 example. I'm not sure whether it prevents the compiler from doing that optimization, but in my experience the performance did not drop after I implemented it.

The best performance I got on A100 with the FP16 datatype after autotuning is 155641.2 GFLOP/s. For comparison, the reference cuBLAS GEMM gave me about 22 TFLOPS (sorry, I can't recall the exact number, and I haven't been able to use an A100 recently).

I'm running out of ideas for further optimization at this point. If anyone could take a look at my code and suggest what the next optimization should be, I would greatly appreciate it!

ccecka commented 1 week ago

I recall something about SM80_16x8x16_F16F16F16F16_TN being significantly more difficult to optimize than SM80_16x8x16_F32F16F16F32_TN, though I forget the details as it's been a while since deep work on Ampere.

Regardless, to give you an example of the same SM80 kernel using LDSM, I've attached my CuTe example for half_t. This kernel should be equivalent to the CUTLASS SM80 Collective implementation and accept very similar configuration parameters.

sgemm_sm80_tmp.txt

Yanksi commented 1 week ago

Thanks a lot for your quick response!

I noticed that your implementation uses a class called Swizzle. I think it is meant to avoid bank conflicts during the smem->rmem copy, but I can't find any documentation for it. Could you explain what exactly the following is doing?

auto swizzle_atom = composition(Swizzle<3,3,3>{},
                                  Layout<Shape <_8,Shape <_8, _8>>,
                                         Stride<_8,Stride<_1,_64>>>{});
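As far as we understand CuTe's Swizzle<B,M,S>, it XORs the B bits starting at bit M+S of an offset into the B bits starting at bit M. For Swizzle<3,3,3> on half_t offsets, bits [6,9) are XORed into bits [3,6), which permutes the eight 16-byte rows inside each 128-byte block of the atom. The host-side sketch below evaluates the swizzled atom from the snippet and counts shared-memory banks (32 banks of 4 bytes). It assumes that a 128-bit cp.async store is serviced eight threads at a time and that those eight threads write one 64-element k-row, as with the 32x8 k-major thread layout in the diff. All function names are hypothetical.

```cpp
#include <array>
#include <cassert>

// Our reading of CuTe's Swizzle<3,3,3>: XOR bits [6,9) of an offset into bits [3,6).
constexpr int swizzle_333(int offset) {
  constexpr int yyy_mask = 0b111 << 6;         // the three bits starting at M+S = 6
  return offset ^ ((offset & yyy_mask) >> 3);  // shifted down by S = 3
}

// Layout<Shape<_8,Shape<_8,_8>>, Stride<_8,Stride<_1,_64>>>: an 8 (M) x 64 (K) atom,
// with column k split as k = k0 + 8*k1.
constexpr int atom_offset(int m, int k) {
  return m * 8 + (k % 8) * 1 + (k / 8) * 64;
}

// composition(Swizzle, Layout): the swizzle is applied to the layout's output offset.
constexpr int swizzled_offset(int m, int k) { return swizzle_333(atom_offset(m, k)); }

// First bank (32 banks x 4 bytes) touched by a half_t element offset.
constexpr int bank_of(int half_offset) { return (half_offset * 2 / 4) % 32; }

// True if eight 16-byte chunks (given by their starting half_t offsets) hit 32 distinct banks.
bool conflict_free(const std::array<int, 8>& chunk_offsets) {
  std::array<bool, 32> used{};
  for (int off : chunk_offsets) {
    for (int b = 0; b < 4; ++b) {  // a 16-byte chunk covers 4 consecutive banks
      const int bank = (bank_of(off) + b) % 32;
      if (used[bank]) return false;
      used[bank] = true;
    }
  }
  return true;
}

// cp.async store: 8 threads write the 8 chunks (k = 0, 8, ..., 56) of row m.
std::array<int, 8> store_row(int m, bool swizzled) {
  std::array<int, 8> offs{};
  for (int k1 = 0; k1 < 8; ++k1)
    offs[k1] = swizzled ? swizzled_offset(m, 8 * k1) : atom_offset(m, 8 * k1);
  return offs;
}

// LDSM-style read of one 8x8 matrix: rows m = 0..7 of the chunk starting at column 8*k1.
std::array<int, 8> ldsm_rows(int k1, bool swizzled) {
  std::array<int, 8> offs{};
  for (int m = 0; m < 8; ++m)
    offs[m] = swizzled ? swizzled_offset(m, 8 * k1) : atom_offset(m, 8 * k1);
  return offs;
}
```

Without the swizzle, all eight chunks of a k-row land in the same four banks (an 8-way conflict on the store); with it they spread over all 32 banks, while an LDSM-style read of eight rows of one chunk stays conflict-free either way.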

About the LDSM copy atoms: I think their naming convention is SM75_<dtype>x<#items>_LDSM_<layout>. If I understand this correctly, why did you use <dtype>=U32 in your code when we are dealing with half? Also, you used the same layout for both A and B in gemm_nt; from my perspective, shouldn't they be different? Finally, what exactly is the difference between atoms with different <#items>, and what is the underlying copy layout in each case?

I found the definitions of those layouts in copy_traits_sm75.hpp, but I am confused about the meaning of the comment "// Map from (src-thr,src-val) to bit".
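The PTX ISA describes the ldmatrix fragment layout for 16-bit elements: for each 8x8 matrix, lane t receives row t/4, columns 2*(t%4) and 2*(t%4)+1, packed into one 32-bit register, and the .x1/.x2/.x4 forms load one, two or four matrices. That suggests (we have not checked every CuTe definition) that U32 in SM75_U32x4_LDSM_N names the per-thread destination register type rather than the element type, and x4 the number of such registers per thread. The sketch below only models that mapping on the host; ldsm_fragment and the other names are hypothetical.

```cpp
#include <cassert>

// Element coordinates held by one lane for one 8x8 matrix of 16-bit values.
struct Fragment {
  int row;
  int col0;  // first of the two halves packed into the lane's 32-bit register
  int col1;  // second of the two halves
};

// ASSUMPTION (PTX ldmatrix, 16-bit elements): lane t holds row t/4, columns 2*(t%4) and +1.
constexpr Fragment ldsm_fragment(int lane) {
  return Fragment{lane / 4, 2 * (lane % 4), 2 * (lane % 4) + 1};
}

// Halves moved by one warp-wide LDSM loading `num_matrices` (1, 2 or 4) 8x8 matrices:
// each of the 32 lanes receives one 32-bit register (two halves) per matrix.
constexpr int ldsm_halves_per_warp(int num_matrices) { return 32 * num_matrices * 2; }

// Row addresses consumed: 8 rows of 16 bytes per matrix, supplied by lanes 0..8*n-1.
constexpr int ldsm_row_addresses(int num_matrices) { return 8 * num_matrices; }

// Checks that the 32 lanes together cover each element of an 8x8 matrix exactly once.
bool fragments_cover_8x8() {
  int hits[8][8] = {};
  for (int lane = 0; lane < 32; ++lane) {
    const Fragment f = ldsm_fragment(lane);
    ++hits[f.row][f.col0];
    ++hits[f.row][f.col1];
  }
  for (const auto& row : hits)
    for (int h : row)
      if (h != 1) return false;
  return true;
}
```

Under this model, .x4 lets all 32 lanes supply a row address and fills four registers per lane in a single instruction, which is why the x4 variant is the one left uncommented in the notes in the diff.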

ccecka commented 6 days ago

print_latex will answer all of those questions. Call it on Layouts, TiledCopy, TiledMMA, etc.

ghostplant commented 7 hours ago

Do you think it can be further improved? The example you provided currently reaches 150 TFLOPS on A100, while cuBLAS gets 300 TFLOPS.

Yanksi commented 7 hours ago

I found this high-performance GEMM implementation using CuTe. The author wrote a series of CuTe tutorials (in Chinese) on Zhihu, and in one of them claimed that this implementation reaches cuBLAS-level performance on an RTX 3090. I'm not sure whether the same holds on A100, as I currently don't have access to one. I think this tutorial series is a very good complement to the official CuTe documentation.

ghostplant commented 6 hours ago

I tried that before; it looks like it has compilation issues because its interfaces no longer match the latest CUTLASS version of CuTe.