NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[BUG] CUTLASS 3.6 profiler doesn't read Instantiation Level that we pass in (Hopper SM90) #1881

Open tonyjie opened 3 days ago

tonyjie commented 3 days ago

Describe the bug

1st bug: the 4-digit CUTLASS_LIBRARY_INSTANTIATION_LEVEL is not used. Here it says that the CUTLASS 3.6 profiler can use an additional flag, CUTLASS_LIBRARY_INSTANTIATION_LEVEL, to instantiate all possible combinations. It also says that "The CUTLASS profiler employs a four-digit integer level (global instantiation level) mechanism to manage the generation of kernel configurations."

However, I found that this 4-digit CUTLASS_LIBRARY_INSTANTIATION_LEVEL is not used when compiling, unless it is set to max.

The process of reading CUTLASS_LIBRARY_INSTANTIATION_LEVEL is as follows:

2nd bug: when I manually modify the default level in generator.py, I get a compile error with warp-specialized kernels (it says the stage count is < 2).

The error message when running make cutlass_profiler is as follows:

~/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp(135): error: static assertion failed with "Specialization requires Stages set to value 2 or more."
    static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
    ^
          detected during instantiation of class "cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_> [with Stages=1, ClusterShape=cute::tuple<cute::_4, cute::_4, cute::_1>, KernelSchedule=cutlass::gemm::KernelTmaWarpSpecializedCooperative, TileShape_=cute::tuple<cute::_128, cute::_128, cute::_256>, ElementA_=cutlass::half_t, StrideA_=cute::tuple<cute::C<1>, int64_t, int64_t>, ElementB_=cutlass::half_t, StrideB_=cute::tuple<cute::C<1>, int64_t, int64_t>, TiledMma_=cute::TiledMMA<cute::MMA_Atom<cute::SM90::GMMA::MMA_64x128x16_F32F16F16_SS<cute::SM90::GMMA::Major::MN, cute::SM90::GMMA::Major::MN, cute::SM90::GMMA::ScaleIn::One, cute::SM90::GMMA::ScaleIn::One>>, cute::Layout<cute::tuple<cute::_2, cute::_1, cute::_1>, cute::tuple<cute::_1, cute::_0, cute::C<0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, GmemTiledCopyA_=cute::SM90_TMA_LOAD_MULTICAST, SmemLayoutAtomA_=cute::ComposedLayout<cute::Swizzle<3, 4, 3>, cute::smem_ptr_flag_bits<16>, cute::Layout<cute::tuple<cute::_64, cute::_8>, cute::tuple<cute::_1, cute::C<64>>>>, SmemCopyAtomA_=void, TransformA_=cute::identity, GmemTiledCopyB_=cute::SM90_TMA_LOAD_MULTICAST, SmemLayoutAtomB_=cute::ComposedLayout<cute::Swizzle<3, 4, 3>, cute::smem_ptr_flag_bits<16>, cute::Layout<cute::tuple<cute::_64, cute::_8>, cute::tuple<cute::_1, cute::C<64>>>>, SmemCopyAtomB_=void, TransformB_=cute::identity]" at line 65 of ~i/cutlass/build/tools/library/generated/gemm/90/f16_s64x64x16gemm_f16/cutlass3x_sm90_tensorop_s64x64x16gemm_f16_f16_f32_f16_f16_128x128x256_4x4x1_0_ntn_align8_stream_k_warpspecialized_cooperative_epi_tma.cu

Steps/Code to reproduce bug

This compile error also occurs when building FP8 kernels.

Additional context

Also, when simply setting CUTLASS_LIBRARY_INSTANTIATION_LEVEL to max, we also get compile errors when building lots of rare shapes like this

hwu36 commented 3 days ago

@alihassanijr, could you please help with this?

alihassanijr commented 3 days ago

Thanks for the detailed description, @tonyjie.

Regarding 1, you are correct: CUTLASS_LIBRARY_INSTANTIATION_LEVEL is limited to pre-defined string levels like "default" and "max". I don't remember exactly why that was; some of it was to preserve previous behavior. But if we want to directly expose the numeric levels to users, we can do that. I'll have to defer to @hwu36 to make that call.
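
For illustration only, exposing the numeric levels could look roughly like the sketch below. It is not the code in generator.py: the function name parse_instantiation_level and the numeric values attached to "default" and "max" are placeholders assumed for this example, and the meaning of each digit is deliberately left out.

```python
# Hypothetical sketch, NOT the actual CUTLASS generator code.
# The numeric values for the named levels are placeholders assumed for illustration.
NAMED_LEVELS = {"default": "0000", "max": "9999"}


def parse_instantiation_level(raw):
    """Accept a named level or a four-digit string; return a tuple of four ints."""
    text = raw.strip()
    # Map a known name (case-insensitive) to its digit string, else keep the input.
    text = NAMED_LEVELS.get(text.lower(), text)
    if len(text) != 4 or not (text.isascii() and text.isdigit()):
        raise ValueError(
            f"expected a named level or four digits, got {raw!r}"
        )
    return tuple(int(ch) for ch in text)
```

With a parser of this shape, a value such as "0500" would be honored as-is instead of silently falling back to the default, and anything malformed would fail loudly at configure time.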

Regarding 2, unfortunately that looks like a much deeper issue, which is that the heuristics set up in the refactored generator don't catch all "invalid" kernels. I'll try and see if I can build with the "max" level and add the necessary heuristics. I'll update this issue soon.
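
As a rough sketch of the kind of heuristic that would catch the kernel in the log above: estimate how many mainloop stages fit in shared memory and skip warp-specialized kernels that end up with fewer than 2. This is a simplified estimate, not CUTLASS's actual stage computation. The function names are hypothetical, the epilogue carveout is an assumed parameter, and barrier storage is ignored. The 227 KB figure is the maximum shared memory per thread block on compute capability 9.0.

```python
# Simplified estimate, NOT the stage computation CUTLASS performs internally.
# Function names are hypothetical; barrier storage is ignored.
SM90_MAX_SMEM_BYTES = 227 * 1024  # max shared memory per thread block on SM90


def mainloop_stages(tile_m, tile_n, tile_k, bits_a, bits_b, carveout_bytes=0):
    """Upper bound on the number of A/B tile stages that fit in shared memory."""
    # One stage holds an (M x K) tile of A and an (N x K) tile of B.
    bytes_per_stage = (tile_m * tile_k * bits_a + tile_n * tile_k * bits_b) // 8
    return (SM90_MAX_SMEM_BYTES - carveout_bytes) // bytes_per_stage


def is_buildable_warpspecialized(tile_m, tile_n, tile_k, bits_a, bits_b,
                                 carveout_bytes=0):
    """Warp-specialized mainloops static_assert on Stages >= 2."""
    return mainloop_stages(tile_m, tile_n, tile_k, bits_a, bits_b,
                           carveout_bytes) >= 2


# The failing kernel from the log: 128x128x256 tile with f16 A and B.
print(mainloop_stages(128, 128, 256, 16, 16))  # 1
```

For the 128x128x256 f16 tile, one stage needs 128 KB, so only a single stage fits in 227 KB, which matches Stages=1 in the error message. A generator-side check along these lines would drop that configuration before it reaches the compiler.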