Yanksi opened 1 week ago
Your rewrite of `sgemm_sm80` with the extra tiling of `mma_k` and an extra loop is unnecessary and could inhibit the compiler from properly pipelining smem, rmem, and MMA. I don't recommend that.

Beyond that, it appears that you're only using row-major/col-major smem, SIMT smem->rmem `TiledCopy`s, and trivial `TiledMMA`s. I agree that we should include a non-trivial SM80 tensor-core example in the tutorial...
I did need to modify the `sgemm_sm80` kernel slightly to extend smem, and got this performance:
**New `HALF_T` for all `TA`, `TB`, `TC`, `TI`**
```
$ ./sgemm_sm80 5120 5120 4096 T N
Using device 0: NVIDIA A100-PCIE-40GB (SM80, 108)
M = 5120
N = 5120
K = 4096
C = A^T B^N
CUTE_GEMM: [126161.8]GFlop/s (1.7022)ms
```
The next step would be to use LDSM for the smem->rmem load (this smem layout is already designed for the LDSM pattern... and is what is slowing it down now), which I've left notes on and we could look at CUTLASS's SM80 Collective to see how that's done. That should achieve speed-of-light, peak performance on A100.
Here's my diff/patch/`half_t` configuration:
```diff
diff --git a/cutlass/examples/cute/tutorial/sgemm_sm80.cu b/cutlass/examples/cute/tutorial/sgemm_sm80.cu
index d042933040..69b8457c1c 100644
--- a/cutlass/examples/cute/tutorial/sgemm_sm80.cu
+++ b/cutlass/examples/cute/tutorial/sgemm_sm80.cu
@@ -46,6 +46,16 @@
 #include "cutlass/util/helper_cuda.hpp"
+template <class ElementA,
+          class ElementB,
+          class SmemLayoutA,
+          class SmemLayoutB>
+struct SharedStorage
+{
+  cute::array<ElementA, cute::cosize_v<SmemLayoutA>> A;
+  cute::array<ElementB, cute::cosize_v<SmemLayoutB>> B;
+};
+
 template <class ProblemShape, class CtaTiler,
           class TA, class AStride, class ASmemLayout, class TiledCopyA,
           class TB, class BStride, class BSmemLayout, class TiledCopyB,
@@ -100,10 +110,11 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
   Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)
   // Shared memory buffers
-  __shared__ TA smemA[cosize_v<ASmemLayout>];
-  __shared__ TB smemB[cosize_v<BSmemLayout>];
-  Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);  // (BLK_M,BLK_K,PIPE)
-  Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);  // (BLK_N,BLK_K,PIPE)
+  extern __shared__ char shared_memory[];
+  using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
+  SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);
+  Tensor sA = make_tensor(make_smem_ptr(smem.A.data()), sA_layout);  // (BLK_M,BLK_K,PIPE)
+  Tensor sB = make_tensor(make_smem_ptr(smem.B.data()), sB_layout);  // (BLK_N,BLK_K,PIPE)
   //
   // Partition the copying of A and B tiles across the threads
@@ -301,6 +312,127 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
   axpby(alpha, tCrC, beta, tCgC);
 }
+template <class Alpha, class Beta>
+void
+gemm_nt(int m, int n, int k,
+        Alpha alpha,
+        cute::half_t const* A, int ldA,
+        cute::half_t const* B, int ldB,
+        Beta beta,
+        cute::half_t      * C, int ldC,
+        cudaStream_t stream = 0)
+{
+  assert(false && "Not implemented");
+}
+
+// Setup params for a TN HGEMM
+template <class Alpha, class Beta>
+void
+gemm_tn(int m, int n, int k,
+        Alpha alpha,
+        cute::half_t const* A, int ldA,
+        cute::half_t const* B, int ldB,
+        Beta beta,
+        cute::half_t      * C, int ldC,
+        cudaStream_t stream = 0)
+{
+  using namespace cute;
+
+  // Define shapes (dynamic)
+  auto M = int(m);
+  auto N = int(n);
+  auto K = int(k);
+  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)
+
+  // Define TN strides (mixed)
+  auto dA = make_stride(ldA, Int<1>{});                      // (dM, dK)
+  auto dB = make_stride(ldB, Int<1>{});                      // (dN, dK)
+  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)
+
+  // Define CTA tile sizes (static)
+  auto bM = Int<128>{};
+  auto bN = Int<128>{};
+  auto bK = Int< 64>{};
+  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
+  auto bP = Int<3>{};  // Pipeline
+
+  // Define the smem layouts (static)
+  // Swizzles for LDSM and 128b k-major loads
+  auto swizzle_atom = composition(Swizzle<3,3,3>{},
+                                  Layout<Shape <_8,Shape <_8, _8>>,
+                                         Stride<_8,Stride<_1,_64>>>{});
+  auto sA = tile_to_shape(swizzle_atom, make_shape(bM,bK,bP));
+  auto sB = tile_to_shape(swizzle_atom, make_shape(bN,bK,bP));
+  auto sC = make_layout(make_shape(bM, bN));
+
+  // Define the thread layouts (static)
+
+  TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::half_t>{},
+                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{},  // Thr layout 32x8 k-major
+                                    Layout<Shape< _1,_8>>{});               // Val layout  1x8 k-major
+  TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, cute::half_t>{},
+                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{},  // Thr layout 32x8 k-major
+                                    Layout<Shape< _1,_8>>{});               // Val layout  1x8 k-major
+
+  TiledMMA mmaC = make_tiled_mma(SM80_16x8x16_F16F16F16F16_TN{},
+                                 Layout<Shape<_2,_4>>{},    // 2x4x1 MMA Atoms
+                                 Tile<_32,_64,_16>{});      // 32x64x16 Tiled MMA for LDSM
+
+  // //Copy_Atom<DefaultCopy, half_t> copy_atom_A;
+  // //Copy_Atom<UniversalCopy<half_t>, half_t> copy_atom_A;
+  // //Copy_Atom<SM75_U32x1_LDSM_N, half_t> copy_atom_A;
+  // //Copy_Atom<SM75_U32x2_LDSM_N, half_t> copy_atom_A;
+  // Copy_Atom<SM75_U32x4_LDSM_N, half_t> copy_atom_A;
+  // TiledCopy copyA = make_tiled_copy_A(copy_atom_A, mmaC);
+
+  // //Copy_Atom<DefaultCopy, half_t> copy_atom_B;
+  // //Copy_Atom<UniversalCopy<half_t>, half_t> copy_atom_B;
+  // //Copy_Atom<SM75_U32x1_LDSM_N, half_t> copy_atom_B;
+  // //Copy_Atom<SM75_U32x2_LDSM_N, half_t> copy_atom_B;
+  // Copy_Atom<SM75_U32x4_LDSM_N, half_t> copy_atom_B;
+  // TiledCopy copyB = make_tiled_copy_B(copy_atom_B, mmaC);
+
+#if 0
+  print(copyA);
+  print(copyB);
+  print(mmaC);
+#endif
+
+#if 0
+  print_latex(copyA);
+  print_latex(copyB);
+  print_latex(mmaC);
+#endif
+
+  int smem_size = int(sizeof(SharedStorage<cute::half_t, cute::half_t, decltype(sA), decltype(sB)>));
+  dim3 dimBlock(size(mmaC));
+  dim3 dimGrid(size(ceil_div(M, bM)),
+               size(ceil_div(N, bN)));
+
+  auto kernel_fptr = gemm_device<
+    decltype(prob_shape), decltype(cta_tiler),
+    cute::half_t, decltype(dA), decltype(sA), decltype(copyA),
+    cute::half_t, decltype(dB), decltype(sB), decltype(copyB),
+    cute::half_t, decltype(dC), decltype(sC), decltype(mmaC),
+    decltype(alpha), decltype(beta)>;
+
+  // Set L1 to be SMEM only
+  cudaFuncSetAttribute(
+    kernel_fptr,
+    cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+
+  cudaFuncSetAttribute(
+    kernel_fptr,
+    cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+
+  kernel_fptr<<<dimGrid, dimBlock, smem_size, stream>>>
+      (prob_shape, cta_tiler,
+       A, dA, sA, copyA,
+       B, dB, sB, copyB,
+       C, dC, sC, mmaC,
+       alpha, beta);
+}
+
 // Setup params for a NT GEMM
 template <class TA, class TB, class TC,
           class Alpha, class Beta>
@@ -362,10 +494,11 @@ gemm_nt(int m, int n, int k,
   print_latex(mmaC);
 #endif
+  int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
   dim3 dimBlock(size(mmaC));
   dim3 dimGrid(size(ceil_div(M, bM)),
                size(ceil_div(N, bN)));
-  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
+  gemm_device<<<dimGrid, dimBlock, smem_size, stream>>>
       (prob_shape, cta_tiler,
        A, dA, sA, copyA,
        B, dB, sB, copyB,
@@ -438,10 +571,11 @@ gemm_tn(int m, int n, int k,
   print_latex(mmaC);
 #endif
+  int smem_size = int(sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>));
   dim3 dimBlock(size(mmaC));
   dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
-  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
+  gemm_device<<<dimGrid, dimBlock, smem_size, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, copyA,
       B, dB, sB, copyB,
@@ -485,6 +619,11 @@ int main(int argc, char** argv)
     return 0;
   }
+  std::cout << "Using device 0: " << props.name
+            << " (SM" << props.major * 10 + props.minor
+            << ", " << props.multiProcessorCount
+            << ")" << std::endl;
+
   int m = 5120;
   if (argc >= 2)
     sscanf(argv[1], "%d", &m);
@@ -505,13 +644,13 @@ int main(int argc, char** argv)
   if (argc >= 6)
     sscanf(argv[5], "%c", &transB);
-  using TA = float;
-  using TB = float;
-  using TC = float;
-  using TI = float;
+  using TA = cute::half_t;
+  using TB = cute::half_t;
+  using TC = cute::half_t;
+  using TI = cute::half_t;
-  TI alpha = static_cast<TI>(1.0f);
-  TI beta  = static_cast<TI>(0.0f);
+  TI alpha = static_cast<TI>(1.0f);
+  TI beta  = static_cast<TI>(0.0f);
   std::cout << "M = " << m << std::endl;
   std::cout << "N = " << n << std::endl;
```
I've updated my project recently and included a README.md for better documentation (hopefully).
> Your rewrite of `sgemm_sm80` with the extra tiling of `mma_k` and extra loop is unnecessary and could inhibit the compiler from properly pipelining smem and rmem and mma. I don't recommend that.
The extra tiling of `mma_k` was there to match the smem->rmem pipeline shown in the `sgemm_sm80` example. I'm not sure whether it prevents the compiler from doing that optimization, but in my experience performance did not drop after it was implemented.
The best performance I got on A100 with the FP16 datatype after autotuning is 155641.2 GFlop/s. On the other hand, the reference cuBLAS GEMM gave me about 22 TFlops (sorry, I cannot recall the exact number, and I have not been able to use an A100 recently).

I am running out of ideas for further optimization at this point. If anyone can take a look at my code and figure out what the next optimization could be, I would greatly appreciate it!!
I recall something about `SM80_16x8x16_F16F16F16F16_TN` being significantly more difficult to optimize than `SM80_16x8x16_F32F16F16F32_TN`, though I forget the details, as it's been a while since I did deep work on Ampere.

Regardless, to give you an example of the same SM80 kernel using LDSM, I've attached my CuTe example for `half_t`. This kernel should be equivalent to the CUTLASS SM80 Collective implementation and accepts very similar configuration parameters.
Thanks a lot for your quick response!!
I noticed that a class called `Swizzle` is used in your implementation. I think it is meant to avoid bank conflicts during the smem->rmem copy. However, I am not able to find documentation for that class. Could you please explain a bit what exactly

```cpp
auto swizzle_atom = composition(Swizzle<3,3,3>{},
                                Layout<Shape <_8,Shape <_8, _8>>,
                                       Stride<_8,Stride<_1,_64>>>{});
```

is doing?
About those LDSM copy atoms: I think their naming convention is `SM75_<dtype>x<#items>_LDSM_<layout>`? If I understand correctly, why were you using `<dtype>=U32` in your code when we are dealing with `half`? Also, you were using the same layout for both A and B in `gemm_nt`, which from my perspective should differ to make more sense? Another question regarding those LDSM atoms: what exactly is the difference between the atoms with different `<#items>`, and what would be the underlying copy layout in each case?

I noticed the definitions of those layouts in copy_traits_sm75.hpp, and I am confused about the meaning of `// Map from (src-thr,src-val) to bit`.
`print_latex` will answer all of those questions. Call it on Layouts, TiledCopys, TiledMMAs, etc.
> I recall something about `SM80_16x8x16_F16F16F16F16_TN` being significantly more difficult to optimize than `SM80_16x8x16_F32F16F16F32_TN`, though I forget the details as it's been a while since deep work on Ampere. Regardless, to give you an example of the same SM80 kernel using LDSM, I've attached my CuTe example for `half_t`. This kernel should be equivalent to the CUTLASS SM80 Collective implementation and accept very similar configuration parameters.
Do you think it can be further improved? The example you provided currently reaches 150 TFlops on A100, while cuBLAS gets 300 TFlops.
I found this high-performance implementation of GEMM using CuTe. The author wrote a series of tutorials on CuTe (in Chinese) on Zhihu, and in one of them claimed this implementation reaches cuBLAS-level performance on an RTX 3090. I am not sure whether the same holds on A100, as I currently don't have access to one. I think this tutorial series is a very good complement to the official CuTe documentation.
I tried that before; it looks like there are compilation issues because its interfaces no longer match the latest CUTLASS CuTe.
Hi, I've just created a small project (link to the project) by modifying the `sgemm_sm80` example. What I was doing was trying to make use of the tensor cores for the computation. Unfortunately, when testing on A100, the performance never seems to reach the peak. The following are the best results I got from the autotuning process; the peak performance reached always seems to be a bit less than half of A100's theoretical peak. Any comments on how I can make this better?