NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[QST] Does simt kernel support gather-gemm-scatter fusion? #675

Closed umiswing closed 1 year ago

umiswing commented 1 year ago

Hello! I wrote a custom SIMT kernel to do gather-gemm-scatter fusion, with the kernel settings picked by the profiler, but it gives the wrong result for gather-gemm-scatter. Does the SIMT kernel support this fusion, or am I misusing it? I am running on an RTX 3060 GPU, and the relevant code is below:

void cutlass_simt_dgemm_128x32_8x2_nn_align1(const GPUContext& dev_ctx,
                                            const double* const a,
                                            const double* const b,
                                            const double* const c,
                                            double* const d,
                                            const int m,
                                            const int n,
                                            const int k,
                                            const int32_t* a_indices,
                                            const int32_t* c_d_indices,
                                            double const alpha,
                                            double const beta) {
  ///////////////////////////////////////////////////////////////////////////////////////////////////

  // The code section below describes datatype for input, output matrices and
  // computation between elements in input matrices.

  // This code section describes whether you want to use tensor cores or regular
  // SIMT cores on GPU SM
  using MMAOp = cutlass::arch::OpClassSimt;

  // This code section describes CUDA SM architecture number
  using SmArch = cutlass::arch::Sm50;

  // This code section describes how threadblocks are scheduled on GPU
  using SwizzleThreadBlock =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;  // <- ??

  // Define the epilogue operation as LinearCombination. This is approximately
  // equal to
  //
  //    d_ij = alpha * sum_k(a_ik * b_kj) + beta * c_ij
  //
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      double,   // <- data type of output matrix
      1,        // <- elements per vectorized memory access (1 for alignment 1)
      double,   // <- data type of accumulator
      double>;  // <- data type of alpha/beta (linear combination scale factors)

  // Number of pipeline stages to use (the Stages template argument below)
  // Ampere -> 4/5
  // Turing -> 2

  using Gemm =
      cutlass::gemm::device::GemmUniversal<double,
                                           cutlass::layout::RowMajor,
                                           double,
                                           cutlass::layout::RowMajor,
                                           double,
                                           cutlass::layout::RowMajor,
                                           double,
                                           MMAOp,
                                           SmArch,
                                           cutlass::gemm::GemmShape<128, 32, 8>,
                                           cutlass::gemm::GemmShape<64, 32, 8>,
                                           cutlass::gemm::GemmShape<1, 1, 1>,
                                           EpilogueOp,
                                           SwizzleThreadBlock,
                                           2,     /*Stages*/
                                           1,     /*AlignmentA*/
                                           1,     /*AlignmentB*/
                                           cutlass::arch::OpMultiplyAdd,
                                           cutlass::ComplexTransform::kNone,
                                           cutlass::ComplexTransform::kNone,
                                           true,  /*GatherA*/
                                           false, /*GatherB*/
                                           true   /*ScatterD*/
                                           >;

  // ================================================================================
  // Initialization setup

  // Create a tuple of problem size for matrix multiplication
  cutlass::gemm::GemmCoord problem_size_real({m, n, k});

  // Split the K dimension into 1 partition
  int split_k_slices = 1;

  // Create a tuple of gemm kernel arguments. This is later passed to the
  // instantiated CUTLASS kernel at launch.
  typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      problem_size_real,  // <- problem size of matrix multiplication
      split_k_slices,     // <- k-dimension split factor
      {alpha, beta},      // <- alpha, beta
      a,                  // <- reference to matrix A on device
      b,                  // <- reference to matrix B on device
      c,                  // <- reference to matrix C on device
      d,                  // <- reference to matrix D on device
      cutlass::layout::RowMajor().capacity(problem_size_real.mk()),  // <- batch stride of A
      cutlass::layout::RowMajor().capacity(problem_size_real.kn()),  // <- batch stride of B
      cutlass::layout::RowMajor().capacity(problem_size_real.mn()),  // <- batch stride of C
      cutlass::layout::RowMajor().capacity(problem_size_real.mn()),  // <- batch stride of D
      problem_size_real.k(),  // <- leading dimension of A (lda)
      problem_size_real.n(),  // <- leading dimension of B (ldb)
      problem_size_real.n(),  // <- leading dimension of C (ldc)
      problem_size_real.n(),  // <- leading dimension of D (ldd)
      a_indices,     // <- pointer to index vector to gather A on device
      nullptr,       // <- pointer to index vector to gather B on device
      c_d_indices};  // <- pointer to index vector to scatter D on device

  // Using the arguments, query for extra workspace required for matrix
  // multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Instantiate CUTLASS kernel depending on templates
  Gemm gemm_op;

  // Check whether the problem size is supported
  cutlass::Status status = gemm_op.can_implement(arguments);
  CUTLASS_CHECK(status);

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm_op.initialize(arguments, workspace.get());
  CUTLASS_CHECK(status);

  // Launch the initialized CUTLASS kernel on the device stream
  status = gemm_op(dev_ctx.stream());
}
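
For reference, below is the result I expect the fused kernel to produce, written as a plain CPU loop. This is my reading of the GatherA/ScatterD template flags above; in particular, reading C at the row given by c_d_indices is an assumption on my part:

#include <cstdint>

// CPU reference for the fused gather-gemm-scatter (assumed semantics):
//   D[c_d_indices[i], :] = alpha * A[a_indices[i], :] * B + beta * C[c_d_indices[i], :]
void reference_gather_gemm_scatter(const double* a,
                                   const double* b,
                                   const double* c,
                                   double* d,
                                   const int m,  // number of gathered rows
                                   const int n,
                                   const int k,
                                   const int32_t* a_indices,    // length m
                                   const int32_t* c_d_indices,  // length m
                                   const double alpha,
                                   const double beta) {
  for (int i = 0; i < m; ++i) {
    const double* a_row = a + static_cast<int64_t>(a_indices[i]) * k;    // gather row of A
    const double* c_row = c + static_cast<int64_t>(c_d_indices[i]) * n;  // assumed: C read at scattered row
    double* d_row = d + static_cast<int64_t>(c_d_indices[i]) * n;        // scatter row of D
    for (int j = 0; j < n; ++j) {
      double acc = 0.0;
      for (int p = 0; p < k; ++p) {
        acc += a_row[p] * b[static_cast<int64_t>(p) * n + j];
      }
      d_row[j] = alpha * acc + beta * c_row[j];
    }
  }
}
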
hwu36 commented 1 year ago

We did not add gather support to the iterators used by the SIMT kernels. It is pretty straightforward: just repeat what we added for the tensor core iterators. Community contributions on this are welcome.
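
Roughly, the change is to route the strided (row) coordinate through the index vector inside the global-memory tile iterator's address computation, the same way the gather-enabled tensor core access iterators do. A minimal, hypothetical sketch of the idea (not the actual CUTLASS iterator code; names and structure are made up for illustration):

#include <cstdint>

// Hypothetical gather-aware row-major iterator sketch. The only difference
// from a plain iterator is the indirection through gather_indices when
// computing the physical row.
template <typename Element>
struct RowMajorGatherIteratorSketch {
  Element const* pointer;     // base pointer of the operand in global memory
  int64_t stride;             // leading dimension (elements per row)
  int const* gather_indices;  // nullptr -> plain contiguous access

  // Address of logical element (row, column) of the tile being loaded.
  __host__ __device__ Element const* at(int row, int column) const {
    // With gather enabled, the logical row is remapped through the index
    // vector before the address is formed.
    int64_t physical_row =
        gather_indices ? static_cast<int64_t>(gather_indices[row]) : row;
    return pointer + physical_row * stride + column;
  }
};

The scatter side would be symmetric: the epilogue's output tile iterator remaps the row through the scatter index vector before storing D.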

umiswing commented 1 year ago

Thank you! :)