Closed: jxybb closed this issue 1 year ago.
Hi, @jz0909.
You're correct that CUTLASS currently only provides SM80+ implementations of RankK kernels. However, with a few changes, example 31 can be made to run RankK kernels on SM75.
Here's a high-level overview of the changes that need to be made, followed by a diff at the end of this comment that shows one (not the cleanest) way I got this working. The answer below assumes FP64 data types.
1. Change the template parameters of cutlass::gemm::device::RankK in example 31 to values that match SM75. In particular, you will need to change cutlass::arch::OpClassTensorOp to cutlass::arch::OpClassSimt, since SM75 does not have FP64 Tensor Cores; cutlass::arch::Sm80 to cutlass::arch::Sm75, to indicate that SM75 is to be used; the GemmShape parameters to <16, 32, 8>, <16, 16, 8>, and <1, 1, 1>, to fit the use of SIMT operations; and the Stages parameter from 5 to 2, since stage counts greater than 2 are only supported on SM80+. You will also need to change the compute-capability check in main() to allow operation on architectures of compute capability 75 or higher.
2. Add a DefaultRankK configuration in include/cutlass/gemm/kernel/default_rank_k.h similar to the one for SM80, but partially specialized on arch::Sm75, arch::OpClassSimt, and a stage count of 2. Importantly, this will need to use a different epilogue than the DefaultEpilogueTensorOpBlas3 used by the SM80 configuration, because that epilogue is specific to Tensor Core kernels. We will define such a default epilogue next.
3. Add a DefaultEpilogueSimtBlas3 implementation by combining the existing default epilogue specialized for BLAS 3 kernels (DefaultEpilogueTensorOpBlas3) with the one specialized for SIMT (DefaultEpilogueSimt).

Here's a diff showing how to put these together quickly:
diff --git a/examples/31_basic_syrk/basic_syrk.cu b/examples/31_basic_syrk/basic_syrk.cu
index 82f4a6a2..c66b5fa1 100644
--- a/examples/31_basic_syrk/basic_syrk.cu
+++ b/examples/31_basic_syrk/basic_syrk.cu
@@ -100,11 +100,11 @@ cudaError_t CutlassSsyrkNN(
ColumnMajor,
cutlass::FillMode::kLower,
double,
- cutlass::arch::OpClassTensorOp,
- cutlass::arch::Sm80,
- cutlass::gemm::GemmShape<16, 32, 16>,
- cutlass::gemm::GemmShape<16, 16, 16>,
- cutlass::gemm::GemmShape<8, 8, 4>,
+ cutlass::arch::OpClassSimt,
+ cutlass::arch::Sm75,
+ cutlass::gemm::GemmShape<16, 32, 8>,
+ cutlass::gemm::GemmShape<16, 16, 8>,
+ cutlass::gemm::GemmShape<1, 1, 1>,
cutlass::epilogue::thread::LinearCombination<
double,
1,
@@ -112,7 +112,7 @@ cudaError_t CutlassSsyrkNN(
double
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
- 5, // Stages
+ 2, // Stages
1, // AligmentA
false, // SplitKSerail
cutlass::arch::OpMultiplyAdd,
@@ -308,7 +308,7 @@ cudaError_t TestCutlassSyrk(int N, int K, double alpha, double beta) {
// Compute leading dimensions for each matrix.
int lda = N;
- int ldc = N;
+ int ldc = N;
// Compute size in bytes of the C matrix.
size_t sizeof_C = sizeof(double) * ldc * N;
@@ -469,9 +469,9 @@ int main(int argc, const char *arg[]) {
return -1;
}
- if (!((props.major * 10 + props.minor) >= 80)) {
+ if (!((props.major * 10 + props.minor) >= 75)) {
- std::cerr << "This example requires compute capability at least 80."
+ std::cerr << "This example requires compute capability at least 75."
<< std::endl;
notSupported = true;
}
diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h
index aef49616..6ee9a3e0 100644
--- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h
+++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h
@@ -57,10 +57,13 @@
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
+#include "cutlass/epilogue/warp/fragment_iterator_simt.h"
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
+#include "cutlass/epilogue/warp/tile_iterator_simt.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
+#include "cutlass/epilogue/threadblock/default_thread_map_simt.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
@@ -168,6 +171,87 @@ struct DefaultEpilogueTensorOpBlas3 {
////////////////////////////////////////////////////////////////////////////////
+/// Defines sensible defaults for epilogues for Simt.
+template <
+ typename Shape_,
+ typename WarpMmaSimt_,
+ // int PartitionsK,
+ typename OutputOp_,
+ int ElementsPerAccess,
+ /// Is for a symmetric kernel
+ BlasMode BlasMode_ = BlasMode::kGemm
+>
+struct DefaultEpilogueSimtBlas3 {
+
+ using Shape = Shape_;
+ using WarpMmaSimt = WarpMmaSimt_;
+ static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK;
+ using OutputOp = OutputOp_;
+ static int const kElementsPerAccess = ElementsPerAccess;
+ static BlasMode const kBlasMode = BlasMode_;
+
+ using ElementOutput = typename OutputOp::ElementOutput;
+ using LayoutC = typename WarpMmaSimt::LayoutC;
+ using ElementAccumulator = typename WarpMmaSimt::ElementC;
+
+ //
+ // Thread map
+ //
+
+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt<
+ Shape,
+ typename WarpMmaSimt::Shape,
+ typename WarpMmaSimt::Policy,
+ kPartitionsK,
+ ElementOutput,
+ kElementsPerAccess
+ >::Type;
+
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorBlas3<
+ OutputTileThreadMap,
+ ElementOutput,
+ kBlasMode
+ >;
+
+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
+ typename WarpMmaSimt::Shape,
+ typename WarpMmaSimt::ThreadMma,
+ layout::RowMajor,
+ typename WarpMmaSimt::Policy
+ >;
+
+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt<
+ typename WarpMmaSimt::Shape,
+ typename WarpMmaSimt::ThreadMma,
+ ElementAccumulator,
+ layout::RowMajor,
+ typename WarpMmaSimt::Policy
+ >;
+
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
+ typename OutputTileThreadMap::CompactedThreadMap,
+ ElementAccumulator
+ >;
+
+ /// Hard-coded padding elements added
+ using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits<ElementAccumulator>::value * 4>;
+
+ //
+ // Define the epilogue
+ //
+ using Epilogue = cutlass::epilogue::threadblock::Epilogue<
+ Shape,
+ WarpMmaSimt,
+ kPartitionsK,
+ OutputTileIterator,
+ AccumulatorFragmentIterator,
+ WarpTileIterator,
+ SharedLoadIterator,
+ OutputOp,
+ Padding
+ >;
+};
+
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
diff --git a/include/cutlass/gemm/kernel/default_rank_k.h b/include/cutlass/gemm/kernel/default_rank_k.h
index 2c0c7a85..850bd304 100644
--- a/include/cutlass/gemm/kernel/default_rank_k.h
+++ b/include/cutlass/gemm/kernel/default_rank_k.h
@@ -241,7 +241,67 @@ struct DefaultRankK<
};
////////////////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////////
+/// Partial specialization for Simt Architecture
+template <
+ /// Element type for A matrix operand
+ typename ElementA,
+ /// Layout type for A matrix operand
+ typename LayoutA,
+ /// Access granularity of A matrix in units of elements
+ int kAlignmentA,
+ /// Element type for C and D matrix operands
+ typename ElementC,
+ /// Fill Mode for C (kLower or kUpper)
+ FillMode FillModeC,
+ /// Element type for internal accumulation
+ typename ElementAccumulator,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape,
+ /// Instruction-level tile size (concept: GemmShape)
+ typename InstructionShape,
+ /// Epilogue output operator
+ typename EpilogueOutputOp,
+ /// Threadblock-level swizzling operator
+ typename ThreadblockSwizzle,
+ /// If true, kernel is configured to support serial reduction in the
+ /// epilogue
+ bool SplitKSerial,
+ /// Operation performed by GEMM
+ typename Operator>
+struct DefaultRankK<
+ ElementA, LayoutA, kAlignmentA,
+ ElementC,layout::RowMajor, FillModeC,
+ ElementAccumulator, arch::OpClassSimt, arch::Sm75,
+ ThreadblockShape, WarpShape, InstructionShape,
+ EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial,
+ Operator> {
+ /// Define the threadblock-scoped matrix multiply-accumulate (A x AT)
+ using Mma = typename cutlass::gemm::threadblock::DefaultMma<
+ ElementA, LayoutA,
+ kAlignmentA,
+ ElementA, typename layout::LayoutTranspose<LayoutA>::type,
+ kAlignmentA,
+ ElementAccumulator, layout::RowMajor, arch::OpClassSimt, arch::Sm75,
+ ThreadblockShape, WarpShape, InstructionShape, 2,
+ Operator>::ThreadblockMma;
+
+
+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
+
+ /// Define the epilogue
+ using Epilogue =
+ typename cutlass::epilogue::threadblock::DefaultEpilogueSimtBlas3<
+ ThreadblockShape, typename Mma::Operator, EpilogueOutputOp,
+ EpilogueOutputOp::kCount, BlasMode::kSymmetric>::Epilogue;
+
+ /// Define the kernel-level RankK operator.
+ using RankKkernel = kernel::RankKUniversal<Mma, Epilogue, ThreadblockSwizzle, FillModeC>;
+};
+////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
Thanks much for your timely, detailed and super clear answer, @jackkosaian! May I ask a follow-up question about the rank_k kernel? We know that SM75 devices also have Tensor Cores for mixed-precision matmul. Can I use FP16 data for the rank_k kernel and get FP32 results?
I read this question a couple of different ways, depending on what is meant by "get FP32 results," so I'd like to clarify to make sure we're on the same page.
If the intended interpretation of this question is along the lines of "Can I achieve FP32-level numerical accuracy using FP16 data in the rank_k kernel?", then I do not think this is possible, except perhaps for highly restricted problem sizes and ranges of input data.
However, if the question is, instead, intended to be interpreted as "Can I use FP16 Tensor Core instructions that accumulate into FP32 with the rank_k kernel?", then the answer is "yes."
You would want to change the InstructionShape to <16, 8, 8> and set an appropriate threadblock shape (see examples in the CUTLASS kernel generation script here). Be sure to set the ElementAccumulator template parameter of the kernel and the epilogue to float. For a full GEMM instantiation similar to this, see the kernel generated under build/tools/library/generated/gemm/cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align8.cu after performing the default CMake configuration for CUTLASS described here.
You will still need to make some of the changes described above to get this to work properly, but not all of them. I believe only steps 1 and 2 are needed, with adjustments for SM75 Tensor Cores (in particular, cutlass::arch::OpClassTensorOp should continue to be used). Step 3 should not be needed, since you should still be able to use DefaultEpilogueTensorOpBlas3 (because you will be using operation class TensorOp).
I have not fully tried out the suggestion above, so I may have missed a step; please feel free to post a follow-up if you have trouble with what I've described.
Thanks much @jackkosaian! This is extremely helpful! Your second interpretation is what I meant to ask, and your answer is exactly what I need. I'm giving it a try and will let you know how it works.
@jackkosaian, I can confirm your answer works well for FP16->FP32 SYRK on an SM75 GPU. Feel free to close the issue. Thanks again for your help.
What is your question?
If I understand correctly, rank_k kernels are by default for sm_80 and newer devices, and it seems I cannot run any rank_k operations with cutlass_profiler on a T4 device. Is it possible for me to run rank_k operations on sm_75 devices? If so, could someone kindly suggest how? Thanks in advance!