NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[QST/BUG] Why does the CuTe kernel transfer so much more data between L2 and gmem than the cuBLAS kernel? #1556

Open irasin opened 1 month ago

irasin commented 1 month ago

What is your question?

I am learning to use CuTe to build an HGEMM kernel. Tested on an A10 GPU, the CuTe kernel performs well for small problem sizes such as m/n/k = 4096, but I found it is much slower than the cuBLAS kernel for m/n/k = 16384/16384/16384, as shown below:

[image]

Here is the ncu profile for m/n/k = 16384/16384/16384: [image]

The biggest difference I found between my CuTe kernel and the cuBLAS kernel is in the memory chart:

cuBLAS kernel: [image]

My CuTe kernel: [image]

I was wondering why my CuTe kernel moves so much more data gmem->L2 and L2->shared than the cuBLAS kernel, and how I should modify it to improve performance for large problem sizes.
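A back-of-the-envelope traffic model helps put the memory chart numbers in context. This is an editor's sketch with hypothetical function names, assuming fp16 operands, one CTA per BM x BN output tile, dimensions divisible by the tile sizes, and no L1 reuse (the kernel's `SM80_CP_ASYNC_CACHEGLOBAL` copies cache only in L2). `compulsory_dram_bytes` is the minimum gmem traffic if every element of A, B, and C is moved exactly once; `l2_to_smem_bytes` is the total the CTAs request from L2, whatever the L2 hit rate.

```cpp
#include <cassert>
#include <cstdint>

// Minimum gmem traffic for C = A * B: read A (M x K) and B (K x N) once, write C (M x N) once.
constexpr std::int64_t compulsory_dram_bytes(std::int64_t M, std::int64_t N, std::int64_t K,
                                             std::int64_t elem_bytes = 2) {
    return (M * K + K * N + M * N) * elem_bytes;
}

// Bytes requested from L2 by (M / BM) * (N / BN) CTAs, each streaming a BM x K slab of A
// and a BN x K slab of B into shared memory over its K loop (hypothetical model).
constexpr std::int64_t l2_to_smem_bytes(std::int64_t M, std::int64_t N, std::int64_t K,
                                        std::int64_t BM, std::int64_t BN,
                                        std::int64_t elem_bytes = 2) {
    return (M / BM) * (N / BN) * (BM + BN) * K * elem_bytes;
}
```

For m/n/k = 16384 with BM = BN = 128 these give 1,610,612,736 bytes (about 1.6 GB) and 137,438,953,472 bytes (128 GiB). L2->shared traffic near the second value is expected for this tiling; gmem->L2 traffic far above the first value suggests tiles are being re-fetched from DRAM after eviction from L2, which is the locality question raised in the replies.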

Here is my CuTe kernel:


#include <cuda_fp16.h>
#include <cute/tensor.hpp>

namespace config {

using namespace cute;

template <int BM, int BN, int BK, int Stage, int SWIZZLE>
struct HGemmConfig {
    using ADataType = half_t;
    using BDataType = half_t;
    using CDataType = half_t;

    using TileShape = Shape<Int<BM>, Int<BN>, Int<BK>>;
    using TiledMma  = TiledMMA<MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>, Layout<Shape<_2, _2, _1>>, Tile<_32, _32, _16>>;

    static constexpr int ThreadCount = size(TiledMma{});

    // A: row-major (m, k)
    using SmemLayoutAtomA = decltype(composition(Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
    using SmemLayoutA     = decltype(tile_to_shape(SmemLayoutAtomA{}, Shape<Int<BM>, Int<BK>, Int<Stage>>{}));
    using G2STiledCopyA   = decltype(make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, ADataType>{}, Layout<Shape<_32, _4>, Stride<_4, _1>>{}, Layout<Shape<_1, _8>>{}));
    using S2RCopyAtomA    = Copy_Atom<SM75_U32x4_LDSM_N, ADataType>;
    using S2RTiledCopyA   = decltype(make_tiled_copy_A(S2RCopyAtomA{}, TiledMma{}));

    // B: row-major (k, n)
    using SmemLayoutAtomB = decltype(composition(Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
    using SmemLayoutB     = decltype(tile_to_shape(SmemLayoutAtomB{}, Shape<Int<BN>, Int<BK>, Int<Stage>>{}));
    using G2STiledCopyB   = decltype(make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, BDataType>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}, Layout<Shape<_8, _1>>{}));
    using S2RCopyAtomB    = Copy_Atom<SM75_U16x8_LDSM_T, BDataType>;
    using S2RTiledCopyB   = decltype(make_tiled_copy_B(S2RCopyAtomB{}, TiledMma{}));

    // C: row-major (m, n)
    using SmemLayoutAtomC = decltype(composition(Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
    using SmemLayoutC     = decltype(tile_to_shape(SmemLayoutAtomC{}, Shape<Int<BM>, Int<BN>>{}));
    using R2SCopyAtomC    = Copy_Atom<UniversalCopy<int>, CDataType>;
    using R2STiledCopyC   = decltype(make_tiled_copy_C(R2SCopyAtomC{}, TiledMma{}));
    using S2GTiledCopyC   = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint128_t>, CDataType>{}, Layout<Shape<_32, _4>, Stride<_4, _1>>{}, Layout<Shape<_1, _8>>{}));

    static constexpr int SmemSizeA = cosize_v<SmemLayoutA> * sizeof(ADataType);
    static constexpr int SmemSizeB = cosize_v<SmemLayoutB> * sizeof(BDataType);
    static constexpr int SmemSizeC = cosize_v<SmemLayoutC> * sizeof(CDataType);
    static constexpr int SmemSize  = cute::max(SmemSizeA + SmemSizeB, SmemSizeC);

    static constexpr int kStage   = Stage;
    static constexpr int kSwizzle = SWIZZLE;
};

}  // namespace config

template <typename Config>
__global__ void CUTEMultiStageKernel(const half* A, const half* B, half* C, int M, int N, int K) {
    using namespace cute;

    constexpr int SWIZZLE = Config::kSwizzle;
    int           tid     = threadIdx.x;
    int           bx      = blockIdx.x / SWIZZLE;
    int           by      = blockIdx.y * SWIZZLE + blockIdx.x % SWIZZLE;

    Tensor mA = make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(K, Int<1>{}));  // (M,K):(K,1)
    Tensor mB = make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(Int<1>{}, N));  // (N,K):(1,N)
    Tensor mC = make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(N, Int<1>{}));  // (M,N):(N,1)

    auto   cta_tiler = typename Config::TileShape{};                             // (BM, BN, BK)
    auto   cta_coord = make_coord(by, bx, _);                                    // (m, n, k)
    Tensor gA        = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{});  // (BM, BK, k_tile)
    Tensor gB        = local_tile(mB, cta_tiler, cta_coord, Step<X, _1, _1>{});  // (BN, BK, k_tile)
    Tensor gC        = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{});  // (BM, BN)

    extern __shared__ uint8_t raw_smem[];
    half*                     p_sA = reinterpret_cast<half*>(raw_smem);
    half*                     p_sB = reinterpret_cast<half*>(raw_smem + Config::SmemSizeA);
    Tensor                    sA   = make_tensor(make_smem_ptr(p_sA), typename Config::SmemLayoutA{});  // (BM, BK, Stage)
    Tensor                    sB   = make_tensor(make_smem_ptr(p_sB), typename Config::SmemLayoutB{});  // (BN, BK, Stage)

    typename Config::TiledMma tiled_mma{};
    ThrMMA                    thr_mma = tiled_mma.get_slice(tid);
    Tensor                    tCrA    = thr_mma.partition_fragment_A(gA(_, _, 0));
    Tensor                    tCrB    = thr_mma.partition_fragment_B(gB(_, _, 0));
    Tensor                    tCrC    = thr_mma.partition_fragment_C(gC);
    clear(tCrC);

    typename Config::G2STiledCopyA tiled_g2s_A{};
    ThrCopy                        thr_g2s_A = tiled_g2s_A.get_slice(tid);
    Tensor                         tAgA      = thr_g2s_A.partition_S(gA);
    Tensor                         tAsA      = thr_g2s_A.partition_D(sA);

    typename Config::G2STiledCopyB tiled_g2s_B{};
    ThrCopy                        thr_g2s_B = tiled_g2s_B.get_slice(tid);
    Tensor                         tBgB      = thr_g2s_B.partition_S(gB);
    Tensor                         tBsB      = thr_g2s_B.partition_D(sB);

    typename Config::S2RTiledCopyA tiled_s2r_A{};
    ThrCopy                        thr_s2r_A   = tiled_s2r_A.get_slice(tid);
    Tensor                         tCsA        = thr_s2r_A.partition_S(sA);
    Tensor                         tCrA_retile = thr_s2r_A.retile_D(tCrA);

    typename Config::S2RTiledCopyB tiled_s2r_B{};
    ThrCopy                        thr_s2r_B   = tiled_s2r_B.get_slice(tid);
    Tensor                         tCsB        = thr_s2r_B.partition_S(sB);
    Tensor                         tCrB_retile = thr_s2r_B.retile_D(tCrB);

    constexpr int STAGE = Config::kStage;
    for (int i = 0; i < STAGE - 1; ++i) {
        cute::copy(tiled_g2s_A, tAgA(_, _, _, i), tAsA(_, _, _, i));
        cute::copy(tiled_g2s_B, tBgB(_, _, _, i), tBsB(_, _, _, i));
        cp_async_fence();
    }

    int k_tile_count    = size<3>(tAgA);
    int k_tile_next     = STAGE - 1;
    int smem_pipe_write = STAGE - 1;
    int smem_pipe_read  = 0;

    cp_async_wait<STAGE - 2>();
    __syncthreads();

    int ik = 0;
    cute::copy(tiled_s2r_A, tCsA(_, _, ik, smem_pipe_read), tCrA_retile(_, _, ik));
    cute::copy(tiled_s2r_B, tCsB(_, _, ik, smem_pipe_read), tCrB_retile(_, _, ik));

    constexpr int CHUNK_K = size<2>(tCrA);

    for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
        for (int ik = 0; ik < CHUNK_K; ++ik) {
            if (ik == 0) {
                if (k_tile_next < k_tile_count) {
                    cute::copy(tiled_g2s_A, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, smem_pipe_write));
                    cute::copy(tiled_g2s_B, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, smem_pipe_write));

                    ++k_tile_next;
                    smem_pipe_write = (smem_pipe_write + 1) % STAGE;
                }
                cp_async_fence();
            }

            cute::gemm(tiled_mma, tCrA(_, _, ik), tCrB(_, _, ik), tCrC);

            if (ik == CHUNK_K - 1) {
                cp_async_wait<STAGE - 2>();
                __syncthreads();
                smem_pipe_read = (smem_pipe_read + 1) % STAGE;
            }

            int ik_next = (ik + 1) % CHUNK_K;

            cute::copy(tiled_s2r_A, tCsA(_, _, ik_next, smem_pipe_read), tCrA_retile(_, _, ik_next));
            cute::copy(tiled_s2r_B, tCsB(_, _, ik_next, smem_pipe_read), tCrB_retile(_, _, ik_next));
        }
    }

    half*  p_sC = reinterpret_cast<half*>(raw_smem);
    Tensor sC   = make_tensor(make_smem_ptr(p_sC), typename Config::SmemLayoutC{});

    typename Config::R2STiledCopyC tiled_r2s_C{};
    ThrCopy                        thr_r2s_C   = tiled_r2s_C.get_slice(tid);
    Tensor                         tCrC_retile = thr_r2s_C.retile_S(tCrC);
    Tensor                         tCsC        = thr_r2s_C.partition_D(sC);
    __syncthreads();
    cute::copy(tiled_r2s_C, tCrC_retile, tCsC);

    typename Config::S2GTiledCopyC tiled_s2g_C{};
    ThrCopy                        thr_s2g_C = tiled_s2g_C.get_slice(tid);
    Tensor                         tDsC      = thr_s2g_C.partition_S(sC);
    Tensor                         tDgC      = thr_s2g_C.partition_D(gC);
    __syncthreads();
    cute::copy(tiled_s2g_C, tDsC, tDgC);
}

void CUTEMultiStage(half* A, half* B, half* C, int M, int N, int K) {
    constexpr int BM      = 128;
    constexpr int BN      = 128;
    constexpr int BK      = 32;
    constexpr int STAGE   = 3;
    constexpr int SWIZZLE = 4;

    using hgemm_config = config::HGemmConfig<BM, BN, BK, STAGE, SWIZZLE>;

    constexpr int smem_max_size = hgemm_config::SmemSize;

    static bool initialized = false;
    if (!initialized) {
        PD_CUDA_CHECK(cudaFuncSetAttribute(CUTEMultiStageKernel<hgemm_config>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max_size));
        initialized = true;
    }

    dim3 block(hgemm_config::ThreadCount);
    dim3 grid(PD_ROUND_DIV(N, BN) * SWIZZLE, PD_ROUND_DIV(PD_ROUND_DIV(M, BM), SWIZZLE));

    CUTEMultiStageKernel<hgemm_config><<<grid, block, smem_max_size>>>(A, B, C, M, N, K);
}
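The block-index remapping at the top of `CUTEMultiStageKernel` can be checked on the host. The sketch below uses hypothetical helper names and plain host C++; it replays `bx = blockIdx.x / SWIZZLE; by = blockIdx.y * SWIZZLE + blockIdx.x % SWIZZLE` and verifies that the launch grid visits every (m, n) tile exactly once, assuming the number of M tiles is divisible by SWIZZLE (true for M = 16384, BM = 128, SWIZZLE = 4).

```cpp
#include <cassert>
#include <set>
#include <utility>

// by (M tile) and bx (N tile) as computed in CUTEMultiStageKernel.
struct TileCoord { int m_tile; int n_tile; };

constexpr TileCoord remap(int block_x, int block_y, int swizzle) {
    return TileCoord{block_y * swizzle + block_x % swizzle, block_x / swizzle};
}

// Walks the grid launched by CUTEMultiStage: grid.x = n_tiles * swizzle,
// grid.y = m_tiles / swizzle. Returns true if each (m, n) tile appears exactly once.
bool covers_all_tiles_once(int m_tiles, int n_tiles, int swizzle) {
    std::set<std::pair<int, int>> seen;
    for (int y = 0; y < m_tiles / swizzle; ++y) {
        for (int x = 0; x < n_tiles * swizzle; ++x) {
            const TileCoord t = remap(x, y, swizzle);
            if (!seen.insert({t.m_tile, t.n_tile}).second) return false;  // duplicate tile
        }
    }
    return static_cast<int>(seen.size()) == m_tiles * n_tiles;
}
```

With SWIZZLE = 4, blockIdx.x = 0..3 map to M tiles 0..3 of N tile 0, so each group of four consecutively launched CTAs shares one B slab while reading four different A slabs.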
thakkarV commented 1 month ago

You are likely thrashing L2 locality by not doing any block ID remapping / swizzling.

irasin commented 1 month ago

Hi @thakkarV, thanks for your reply. The performance is still bad after I removed the thread block swizzle.

irasin commented 1 month ago

I suspect this abnormal data reading is related to the cp.async in the G2S (gmem-to-smem) part, but I don't know the specific cause.
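One part of the cp.async path that can be ruled out on the host is the pipeline bookkeeping itself. The sketch below is a hypothetical helper that tracks indices only, no data: it replays the kernel's `k_tile_next`, `smem_pipe_write`, `smem_pipe_read`, `cp_async_fence()` and `cp_async_wait<STAGE - 2>()` sequence and checks that each k-tile is fetched exactly once per CTA, is read only after its cp.async group has completed, and never lands in the stage currently being read. Like the kernel's prologue, it assumes `k_tile_count >= STAGE - 1`.

```cpp
#include <cassert>
#include <vector>

// Replays the multistage pipeline bookkeeping of CUTEMultiStageKernel (indices only).
bool pipeline_indices_consistent(int k_tile_count, int stage) {
    std::vector<int> tile_in_stage(stage, -1);  // which k-tile each smem stage holds
    int fetched = 0;                            // G2S tile copies issued
    int committed = 0;                          // cp.async groups committed (one per fence)
    for (int i = 0; i < stage - 1; ++i) {       // prologue: tiles 0..STAGE-2, one group each
        tile_in_stage[i] = i;
        ++fetched;
        ++committed;
    }
    // cp_async_wait<STAGE - 2>: all but the newest STAGE-2 groups have completed.
    int completed = committed - (stage - 2);
    int k_tile_next = stage - 1, smem_pipe_write = stage - 1, smem_pipe_read = 0;

    for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
        // ik == 0: issue the next G2S copy if one remains; the fence commits a group regardless.
        if (k_tile_next < k_tile_count) {
            if (smem_pipe_write == smem_pipe_read) return false;  // would clobber the stage being read
            tile_in_stage[smem_pipe_write] = k_tile_next++;
            ++fetched;
            smem_pipe_write = (smem_pipe_write + 1) % stage;
        }
        ++committed;
        // Group g carries tile g, so tile k_tile must be complete and resident when read.
        if (k_tile >= completed || tile_in_stage[smem_pipe_read] != k_tile) return false;
        // ik == CHUNK_K - 1: wait, __syncthreads(), advance the read stage.
        completed = committed - (stage - 2);
        smem_pipe_read = (smem_pipe_read + 1) % stage;
    }
    return fetched == k_tile_count;
}
```

For K = 16384 and BK = 32 (512 k-tiles) with STAGE = 3 the replay passes, which suggests the G2S indexing does not issue duplicate tile requests within a CTA; extra gmem->L2 traffic would then have to come from L2 misses across CTAs.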

I also found that if I change

constexpr int BN      = 128;

to

constexpr int BN      = 256;

the abnormal data reading disappears. Here is the ncu result: [image]

The gmem-to-L2 traffic is still higher than in the cuBLAS version, but I think it is acceptable here.

The throughput for various large input sizes looks good too: [image]

However, I don't understand how the block tile shape affects L2 cache locality here. Is there something wrong in my code, or a bug somewhere? Can anyone help?
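Two quantities that do change with BN can be computed directly from `HGemmConfig`: the shared memory per CTA (mirroring its `SmemSize` formula) and, from that, how many CTAs fit on one SM. The sketch below is hypothetical host code; the 100 KB of shared memory per SM is an assumption for the A10 (sm_86), and registers, thread limits, and the per-block reserved shared memory are ignored. It does not explain the gmem->L2 numbers by itself; it only shows what else moves when BN goes from 128 to 256.

```cpp
#include <algorithm>
#include <cassert>

// Mirrors HGemmConfig::SmemSize for fp16: max(multistage A + B buffers, epilogue C tile).
constexpr int smem_bytes(int BM, int BN, int BK, int stage) {
    return std::max((BM * BK + BN * BK) * stage * 2, BM * BN * 2);
}

// CTAs per SM allowed by shared memory alone (smem_per_sm is an assumed input).
constexpr int ctas_per_sm_by_smem(int smem_per_cta, int smem_per_sm) {
    return smem_per_sm / smem_per_cta;
}
```

BN = 128 gives 49,152 bytes per CTA (2 CTAs per SM by shared memory), while BN = 256 gives 73,728 bytes (1 CTA per SM), so the larger tile also halves the number of concurrently resident CTAs. In the simple tiling model, bytes requested from L2 scale with (BM + BN) / (BM * BN), which drops from 1/64 to 3/256, i.e. 25% less L2->shared traffic.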

ccecka commented 1 month ago

By "block ID remapping / swizzling", Vijay meant the Tile Schedulers that are part of CUTLASS and cuBLAS, not swizzling of the data or threadblocks. CUTLASS and cuBLAS have many Tile Schedulers, including Split-K, Stream-K, and other skews of the assignment of block IDs to work tiles. Many of these strategies aim to increase L2 locality. cuBLAS uses a heuristic to choose the best Tile Scheduler for your problem, and that choice certainly changes with problem size.
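The simplest of the schedulers named above, Split-K, comes down to an index calculation: the K loop of one output tile is divided among several CTAs, each producing a partial C tile that must be summed afterwards. This is a generic host-side illustration with a hypothetical `split_k_range` helper, not the CUTLASS or cuBLAS implementation.

```cpp
#include <cassert>
#include <utility>

// Half-open range [begin, end) of k-tiles handled by split `s` when the `k_tiles` K iterations
// of one output tile are divided among `splits` CTAs (e.g. one per blockIdx.z, an assumption).
// The first k_tiles % splits splits take one extra k-tile, so every k-tile is covered once.
std::pair<int, int> split_k_range(int k_tiles, int splits, int s) {
    const int base = k_tiles / splits;
    const int rem = k_tiles % splits;
    const int begin = s * base + (s < rem ? s : rem);
    const int end = begin + base + (s < rem ? 1 : 0);
    return {begin, end};
}
```

Split-K adds parallelism when there are too few output tiles to fill the GPU, at the cost of writing and reducing partial results.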

For more information, I recommend the (full version of our) Stream-K paper: https://arxiv.org/pdf/2301.03598
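The core idea of the basic Stream-K decomposition is to give each CTA an equal share of the total MAC-loop iterations rather than a whole number of output tiles. A minimal host-side sketch of that partition follows; the names are hypothetical, and it omits the cross-CTA fixup that combines partial sums of tiles split between CTAs, as well as hybrid data-parallel + Stream-K variants.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// One contiguous piece of work: k-iterations [k_begin, k_end) of output tile `tile`.
struct Segment { int tile; int k_begin; int k_end; };

// The tiles * iters_per_tile MAC-loop iterations are ordered tile by tile and cut into
// num_ctas contiguous ranges of ceil(total / num_ctas) iterations each.
std::vector<Segment> stream_k_segments(int tiles, int iters_per_tile, int num_ctas, int cta) {
    const long long total = 1LL * tiles * iters_per_tile;
    const long long per_cta = (total + num_ctas - 1) / num_ctas;
    long long it = std::min(total, cta * per_cta);
    const long long end = std::min(total, it + per_cta);
    std::vector<Segment> segs;
    while (it < end) {
        const int tile = static_cast<int>(it / iters_per_tile);
        const int k_begin = static_cast<int>(it % iters_per_tile);
        // Stop at the end of this tile or the end of this CTA's share, whichever comes first.
        const int k_end = static_cast<int>(std::min<long long>(iters_per_tile, k_begin + (end - it)));
        segs.push_back({tile, k_begin, k_end});
        it += k_end - k_begin;
    }
    return segs;
}
```

With 3 tiles of 4 k-iterations on 2 CTAs, CTA 0 computes all of tile 0 plus the first half of tile 1, and CTA 1 finishes tile 1 and computes tile 2. Both CTAs do 6 iterations, and tile 1 needs a fixup.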

ssiu commented 1 month ago

Hi @ccecka, can I ask a basic question? I skimmed through the Stream-K paper, and my (rudimentary) understanding is that it beats cuBLAS by better minimizing the tail effect. In this situation, where the GEMM is large, I assume there are lots of threadblocks, so would we expect cuBLAS performance to be more or less the same as Stream-K?

In other words, does Stream-K outperform cuBLAS for GEMMs with a large number of blocks?
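The tail effect for a one-CTA-per-tile launch can be estimated from wave quantization: with `slots` CTAs resident at once, the last wave may be partly empty. The sketch below is a rough model with a hypothetical helper; the 144 slots assume the A10's 72 SMs with 2 resident 128x128 CTAs each, and it ignores that CTAs do not finish in lockstep.

```cpp
#include <cassert>

// Fraction of CTA slots doing useful work over the whole launch when `tiles` CTAs run
// `slots` at a time (slots = SMs * resident CTAs per SM); the last wave may be partial.
constexpr double wave_efficiency(long long tiles, long long slots) {
    return static_cast<double>(tiles) /
           static_cast<double>(((tiles + slots - 1) / slots) * slots);
}
```

With 128x128 tiles, 16384^3 yields 16384 tiles and an efficiency of about 0.998, while 4096^3 yields 1024 tiles and about 0.889. Under this model the tail effect is small at the large size, so on its own it would not account for the gap to cuBLAS there.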

Thanks!

irasin commented 2 weeks ago

Any update? Should I consider this the expected behavior for this CuTe kernel?