NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines
Other
5.76k stars 987 forks source link

[QST]error: too few arguments for class template "cutlass::epilogue::collective::DefaultEpilogue" #1589

Open sleepwalker2017 opened 5 months ago

sleepwalker2017 commented 5 months ago

I ran the example from the quick start guide.

My GPU is an A30, and the command is `nvcc 01_gemm_3.0.cu -arch=sm_80`. It reports the following errors:

01_gemm_3.0.cu(51): error: too few arguments for class template "cutlass::epilogue::collective::DefaultEpilogue"
        cutlass::epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>>;
                                                                                                        ^

01_gemm_3.0.cu(61): error: incomplete type is not allowed
    Gemm gemm_op;
         ^

01_gemm_3.0.cu(79): error: incomplete type is not allowed
    cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
                                       ^

01_gemm_3.0.cu(80): error: incomplete type is not allowed
    cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
                                       ^

01_gemm_3.0.cu(81): error: incomplete type is not allowed
    cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
                                       ^

01_gemm_3.0.cu(82): error: incomplete type is not allowed
    cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
                                       ^

01_gemm_3.0.cu(84): error: incomplete type is not allowed
    using StrideA = typename Gemm::GemmKernel::StrideA;
                             ^

01_gemm_3.0.cu(85): error: incomplete type is not allowed
    using StrideB = typename Gemm::GemmKernel::StrideB;
                             ^

01_gemm_3.0.cu(86): error: incomplete type is not allowed
    using StrideC = typename Gemm::GemmKernel::StrideC;
                             ^

01_gemm_3.0.cu(87): error: incomplete type is not allowed
    using StrideD = typename Gemm::GemmKernel::StrideD;
                             ^

01_gemm_3.0.cu(94): error: no instance of overloaded function "cutlass::make_cute_packed_stride" matches the argument list
            argument types are: (<error-type>, cute::tuple<int, int, cute::C<1>>)
    stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, Int<1>{}));
               ^

01_gemm_3.0.cu(95): error: no instance of overloaded function "cutlass::make_cute_packed_stride" matches the argument list
            argument types are: (<error-type>, cute::tuple<int, int, cute::C<1>>)
    stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{}));
               ^

01_gemm_3.0.cu(96): error: no instance of overloaded function "cutlass::make_cute_packed_stride" matches the argument list
            argument types are: (<error-type>, cute::tuple<int, int, cute::C<1>>)
    stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{}));
               ^

01_gemm_3.0.cu(97): error: no instance of overloaded function "cutlass::make_cute_packed_stride" matches the argument list
            argument types are: (<error-type>, cute::tuple<int, int, cute::C<1>>)
    stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{}));
               ^

14 errors detected in the compilation of "01_gemm_3.0.cu".

Here is my code:

#include <cstdint>
#include <cstdio>

#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

#include "cutlass/util/device_memory.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"

using namespace cute;

int main(int argc, char const **args) {

  // A matrix configuration
  using         ElementA    = cutlass::half_t;                                // Element type for A matrix operand
  using         LayoutA     = cutlass::layout::RowMajor;                      // Layout type for A matrix operand
  constexpr int AlignmentA  = 128 / cutlass::sizeof_bits<ElementA>::value;    // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using         ElementB    = cutlass::half_t;                                // Element type for B matrix operand
  using         LayoutB     = cutlass::layout::ColumnMajor;                   // Layout type for B matrix operand
  constexpr int AlignmentB  = 128 / cutlass::sizeof_bits<ElementB>::value;    // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

  // C/D matrix configuration
  using         ElementC    = cutlass::half_t;                                // Element type for C and D matrix operands
  using         LayoutC     = cutlass::layout::ColumnMajor;                   // Layout type for C and D matrix operands

  // Core kernel configurations
  using ElementAccumulator  = float;                                          // Element type for internal accumulation
  using ArchTag             = cutlass::arch::Sm80;                            // Tag indicating the minimum SM that supports the intended feature
  using OperatorClass       = cutlass::arch::OpClassTensorOp;                 // Operator class tag
  using TilesShape          = Shape<_128,_128,_64>;                           // Threadblock-level tile size
  using ClusterShape        = Shape<_1,_2,_1>;                                // Shape of the threadblocks in a cluster
  using StageCountType = cutlass::gemm::collective::StageCountAuto;           // Stage count maximized based on the tile size
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;       // Kernel to launch based on the default setting in the Collective Builder

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      ElementA, LayoutA, AlignmentA,
      ElementB, LayoutB, AlignmentB,
      ElementAccumulator,
      TilesShape, ClusterShape,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::collective::KernelScheduleAuto
    >::CollectiveOp;
  using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
      cutlass::gemm::TagToStrideC_t<LayoutC>,
      cutlass::gemm::TagToStrideC_t<LayoutC>,
      cutlass::epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>>;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int>, // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue
  >;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  Gemm gemm_op;
  cutlass::Status status;

  //
  // Define the problem size
  //

  int M = 512;
  int N = 256;
  int K = 128;

  float alpha = 1.25f;
  float beta = -1.25f;

  //
  // Allocate device memory
  //

  cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
  cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
  cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
  cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  StrideA stride_A;
  StrideB stride_B;
  StrideC stride_C;
  StrideD stride_D;

  stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, Int<1>{}));
  stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{}));
  stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{}));
  stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{}));

  block_A.reset(M * K);
  block_B.reset(K * N);
  block_C.reset(M * N);
  block_D.reset(M * N);

  //
  // Launch GEMM on the device
  //

  status = gemm_op({
    cutlass::gemm::GemmUniversalMode::kGemm,
    {M, N, K},
    block_A.get(),
    stride_A,
    block_B.get(),
    stride_B,
    {block_C.get(), stride_C, block_D.get(), stride_D, {alpha, beta}}
  });

  if (status != cutlass::Status::kSuccess) {
    return -1;
  }

  return 0;
}
github-actions[bot] commented 4 months ago

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot] commented 1 month ago

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.

wbrickner commented 1 week ago

Yeah, I'm encountering very similar issues and haven't been able to get any CUTLASS code to build. I suspect we're doing something simple wrong; I'm rusty with the C build process.

horrorChen commented 1 week ago

I think the code in quickstart.md is out of date. DefaultEpilogue requires an additional template argument, EpilogueSchedule_, but when I supply it as cutlass::epilogue::collective::EpilogueScheduleAuto (and include its header file, cutlass/epilogue/collective/collective_builder.hpp), compilation fails with a different error:

mixed_gemm.cu(110): error: no instance of overloaded function "cutlass::gemm::device::GemmUniversalAdapter<GemmKernel_, std::enable_if_t<cutlass::gemm::detail::IsCutlass3GemmKernel<GemmKernel_, void>::value, void>>::operator() [with GemmKernel_=cutlass::gemm::kernel::GemmUniversal<cute::tuple<int32_t, int32_t, int32_t>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized<7, cute::tuple<cute::_1, cute::_2, cute::_1>, cutlass::gemm::KernelTmaWarpSpecializedCooperative>, cute::tuple<cute::_128, cute::_128, cute::_64>, cutlass::half_t, cute::tuple<int64_t, cute::C<1>, int64_t>, cutlass::half_t, cute::tuple<int64_t, cute::C<1>, int64_t>, cute::TiledMMA<cute::MMA_Atom<cute::SM90::GMMA::MMA_64x128x16_F32F16F16_SS<cute::SM90::GMMA::Major::K, cute::SM90::GMMA::Major::K, cute::SM90::GMMA::ScaleIn::One, cute::SM90::GMMA::ScaleIn::One>>, cute::Layout<cute::tuple<cute::_2, cute::_1, cute::_1>, cute::tuple<cute::_1, cute::_0, cute::_0>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM90_TMA_LOAD_MULTICAST, cute::ComposedLayout<cute::Swizzle<3, 4, 3>, cute::smem_ptr_flag_bits<16>, cute::Layout<cute::tuple<cute::_8, cute::_64>, cute::tuple<cute::_64, cute::_1>>>, void, cute::identity, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<3, 4, 3>, cute::smem_ptr_flag_bits<16>, cute::Layout<cute::tuple<cute::_8, cute::_64>, cute::tuple<cute::_64, cute::_1>>>, void, cute::identity>, cutlass::epilogue::collective::DefaultEpilogue<cute::tuple<cute::C<1>, int64_t, int64_t>, cute::tuple<cute::C<1>, int64_t, int64_t>, cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 1, float, float, cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest, cutlass::half_t>, cutlass::epilogue::collective::EpilogueScheduleAuto>, void, void>]" matches the argument list
            argument types are: ({...})
            object type is: Gemm
    status = gemm_op({