ROCm / rocWMMA

rocWMMA
https://rocm.docs.amd.com/projects/rocWMMA/
MIT License

[Issue]: gemm tests failed in ROCM 6.2 #444

Open yiakwy-xpu-ml-framework-team opened 3 days ago

yiakwy-xpu-ml-framework-team commented 3 days ago

Problem Description

I am investigating the usage of the instruction v_mfma_f32_16x16x16_f16 and the NVIDIA equivalent warp-level MMA (swizzled SRAM memory + ldmatrix into registers + mma over registers, Ampere-style computation) with multiple blocks.

I found that our gemm tests fail in ROCm 6.2:

[Screenshot 2024-09-26 18:58:15: failing gemm test output]

Operating System

Ubuntu 22.04

CPU

AMD EPYC 9534 64-Core Processor

GPU

AMD Instinct MI300X

ROCm Version

ROCm 6.2.0

ROCm Component

rocWMMA

Steps to Reproduce

Build rocWMMA and run the gemm_PGR0_LB0_MP0_MB_NC_ad_hoc-validate tests (multiple blocks, no memory optimization).

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

cgmillette commented 2 days ago

Hello @yiakwy-xpu-ml-framework-team, thanks for reaching out! I've created an internal ticket for investigation. Will report back with updates as they become available.

yiakwy-xpu-ml-framework-team commented 21 hours ago

@cgmillette I am not sure if you are the right person for this question, but I hope you can give me some feedback on these three kernels if you are interested, since I found no direct assembly usage of our wmma instruction.

Luckily I have just decoded the layout of wmma on AMD chips for the CDNA3 arch, but there is still a minor problem with direct use of the instruction.

I propose three test kernels (shown in the test code below).

__builtin_amdgcn_mfma_f32_16x16x16f16

Memory transaction estimation for this instruction:

    // total loads : 16 x 16 x 2 = 512 bytes
    // a memory transaction : 4 bytes (32 bits) x 64 lanes = 256 bytes
    // phases : 2 = 512 / 256 (i.e. a 64-thread wave needs at least two transactions to load all the data from smem into registers (VGPRs))
    // elements per thread : 16 x 16 / 64 = 4

We can apply the same analysis to the other wmma instructions.
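For illustration, here is a minimal host-side sketch of the same arithmetic, generalized to an M x N fp16 tile. The helper estimate() below is hypothetical (not part of rocWMMA) and assumes a 64-lane wave and one 32-bit load per lane per transaction:

#include <cstdio>

// Hypothetical helper, for illustration only: estimate the per-wave transfer cost of
// one MxN fp16 fragment, assuming a 64-lane wave and 4-byte (32-bit) loads per lane.
struct FragCost {
    unsigned total_bytes;    // bytes in one MxN fp16 tile
    unsigned bytes_per_txn;  // bytes moved by one full-wave transaction
    unsigned phases;         // transactions needed to cover the tile
    unsigned elems_per_lane; // fp16 elements owned by each lane
};

constexpr FragCost estimate(unsigned m, unsigned n,
                            unsigned wave_size = 64, unsigned bytes_per_lane = 4) {
    return FragCost{
        m * n * 2,                                   // fp16 -> 2 bytes per element
        bytes_per_lane * wave_size,                  // e.g. 4 x 64 = 256 bytes
        (m * n * 2) / (bytes_per_lane * wave_size),  // e.g. 512 / 256 = 2
        (m * n) / wave_size                          // e.g. 256 / 64 = 4
    };
}

int main() {
    constexpr FragCost c = estimate(16, 16);         // the 16x16x16 f16 case above
    std::printf("total=%uB txn=%uB phases=%u elems/lane=%u\n",
                c.total_bytes, c.bytes_per_txn, c.phases, c.elems_per_lane);
    // Expected: total=512B txn=256B phases=2 elems/lane=4
    return 0;
}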

As for the memory layout, I reproduced the single-warp result in the scripts below with a full explanation.

Test Codes

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

#include <iostream>

using half = __half;
using float16_t = _Float16;

#include <rocwmma/rocwmma.hpp>

#define M 16
#define N 16
#define K 16

__global__ void v_mfma_intrinsics_test(float* output, unsigned int lda, unsigned int ldb) {
    int lane_id = threadIdx.x % rocwmma::Constants::AMDGCN_WAVE_SIZE ;

    // coordinates : see ROCM SDK MappingUtils 
    uint global_row = (blockIdx.x * blockDim.x + threadIdx.x) / 64 * M;
    uint global_col = (blockIdx.y * blockDim.y + threadIdx.y)      * N;

    // in this test lda == BLOCK_SIZE_M (M), ldb == BLOCK_SIZE_N (N)

    // NOTE(yiakwy) : I produce frag layout based on research https://github.com/yiakwy-xpu-ml-framework-team/AMD-lab-notes-fork/blob/release/matrix-cores/src/mfma_fp32_16x16x16fp16.cpp
    // a_frag is stored in a 2 x VGPR pair (32 lanes x 2, see CDNA3 ISA); each thread processes 4 fp16 elements;
    //   lane     0-15        16-31       32-47      48-63 
    //  Reg\Col   0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
    // a_frag[0]  x           x           x           x
    // a_frag[1]     *           *           *          *
    // ...
    // a_frag[3]           +           +                       +
    //
    // Here, a_frag[0] is accessed by 16 consecutive threads (T0-T15), each thread covers 4 elements
    //
    //                                                      lane
    //  row\Reg   b_frag[0] b_frag[1] b_frag[2] b_frag[3]
    //    0       x                                         0-15
    //    1                 *
    //    2                           &
    //    3                                     +
    //    0+4     x                                         16-31
    //    1+4               *
    //    2+4                         &
    //    3+4                                   +
    //    0+8     x                                         32-47
    //    1+8               *
    //    2+8                         &
    //    3+8                                   +
    //    0+12    x                                         48-63
    //    1+12              *
    //    2+12                        &
    //    3+12                                  +
    using float16x4 = __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t;
    uint a_frag[2] = {0};
    uint b_frag[2] = {0};

    float16x4 *a = reinterpret_cast<float16x4 *>(a_frag);
    float16x4 *b = reinterpret_cast<float16x4 *>(b_frag);

    // the output fragment is stored in 4 x AccVGPRs (see CDNA3 ISA), each thread processes 4 fp32 elements
    using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
    uint acc_frag[4] = {0};

    floatx4 *d = reinterpret_cast<floatx4 *>(acc_frag);

    unsigned int ele_per_thread = M * N / rocwmma::Constants::AMDGCN_WAVE_SIZE ;

    for (int i=0; i < K; i+= K) {

        // uint a_warp_offset = i + global_row * lda;
        // uint b_warp_offset = i + global_col * ldb;

        // rocwmma::load_matrix_sync(a_frag, a + a_offset, M);
        // rocwmma::load_matrix_sync(b_frag, b + b_offset, K);

        *d = __builtin_amdgcn_mfma_f32_16x16x16f16(*a, *b, *d, 0, 0, 0);
    }

    // Note(yiakwy) : storing thread-private memory back to global memory is tricky. Suppose the output is still row-major layout; then the thread configuration is (4/*x*/, 4/*j*/, 16/*y*/) :
    //
    // lane                   0     1         15     16     17         32  ...    63
    //  Row(x) \ Col (j,y) <0,0> <0,1> ... <0,15> <1, 0> <1, 1> ... <1,15> ... <4,15>
    //   0                   x
    //   1                   x
    //   2                   x
    //   3                   x
    //
    // x = output[warp_offset + coord2Index(x, j, y)], coord2Index : (x, j, y) -> ( y + j * 16 ) + x * 64
    for (int j=0; j < ele_per_thread; j++) {
        auto x = ( threadIdx.x / 16 ) % 4;
        auto y = threadIdx.x % 16; // y + j * 16 is the output data lane ID
        auto outIdx = (global_row * N + global_col) + y + j * 16 + x * 64;
        output[outIdx] = (*d)[j];
    }
}

// a new single warp test for mfma instruction
__global__ void v_mfma_asm_test(float* output, unsigned int lda, unsigned int ldb) {
    int lane_id = threadIdx.x % rocwmma::Constants::AMDGCN_WAVE_SIZE ;

    // coordinates : see ROCM SDK MappingUtils 
    uint global_row = (blockIdx.x * blockDim.x + threadIdx.x) / 64 * M;
    uint global_col = (blockIdx.y * blockDim.y + threadIdx.y)      * N;

    using float16x4 = __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t;
    uint a_frag[2] = {0};
    uint b_frag[2] = {0};

    float16x4 *a = reinterpret_cast<float16x4 *>(a_frag);
    float16x4 *b = reinterpret_cast<float16x4 *>(b_frag);

    for (int j=0; j < 4; j++) {
        (*a)[j] = (half)1.f;
        (*b)[j] = (half)1.f;
    }

    // the output fragment is stored in 4 x AccVGPRs (see CDNA3 ISA)
    using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
    uint acc_frag[4] = {0};

    floatx4 *d = reinterpret_cast<floatx4 *>(acc_frag);

    for (int i=0; i < K; i+= K) {

        // uint a_offset = i + global_row * lda;
        // uint b_offset = i + global_col * ldb;

        // rocwmma::load_matrix_sync(a_frag, a + a_offset, M);
        // rocwmma::load_matrix_sync(b_frag, b + b_offset, K);

       unsigned int ele_per_thread = M * N / rocwmma::Constants::AMDGCN_WAVE_SIZE ;

       // 4 x outer product accumulation
       // *d = __builtin_amdgcn_mfma_f32_16x16x16f16(*a, *b, *d, 0, 0, 0);
        asm volatile("v_mfma_f32_16x16x16_f16 "
                        "%0, "
                        "%1, "
                        "%2, 0;\n"
                        : "=v"(*d)
                        :  "v"(*a), 
                           "v"(*b));

       // store back to global memory
       for (int j=0; j < ele_per_thread; j++) {
            auto x = ( threadIdx.x / 16 ) % 4;
            auto y = threadIdx.x % 16; // y + j * 16 is the output data lane ID
            auto outIdx = (global_row * N + global_col) + y + j * 16 + x * 64;
            output[outIdx] = (*d)[j];
       }
    }
}

__global__ void v_mfma_test(float* output, unsigned int lda, unsigned int ldb)
{
    int lane_id = threadIdx.x % rocwmma::Constants::AMDGCN_WAVE_SIZE ;

    // coordinates : see ROCM SDK MappingUtils 
    uint global_row = (blockIdx.x * blockDim.x + threadIdx.x) / 64 * M;
    uint global_col = (blockIdx.y * blockDim.y + threadIdx.y)      * N;

    rocwmma::fragment<rocwmma::matrix_a, 16, 16, 16, half, rocwmma::row_major> a_frag;
    rocwmma::fragment<rocwmma::matrix_b, 16, 16, 16, half, rocwmma::col_major> b_frag;

    rocwmma::fragment<rocwmma::accumulator, 16, 16, 16, float> acc_frag; // ComputeT

    rocwmma::fill_fragment(acc_frag, 0.0f);

    // total loads : 16 x 16 x 2 = 512 bytes
    // a memory transaction : 4 bytes (32 bits) x 64 lanes = 256 bytes
    // phases : 2 = 512 / 256
    // elements per thread : 16 x 16 / 64 = 4

    for (int i = 0; i < K; i+=K) {
        rocwmma::fill_fragment(a_frag, (half)(1 / 16.0));        
        rocwmma::fill_fragment(b_frag, (half)1.0);

        // Matrix multiply - accumulate using MFMA units
        rocwmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
    }

    // output is row major
    rocwmma::store_matrix_sync(output + (global_row * N + global_col), acc_frag, M, rocwmma::mem_row_major);

}

int main(int argc, char * argv[]) {
    int device_id = 0;
    hipGetDevice(&device_id);
    int major = 0, minor = 0;
    hipDeviceComputeCapability(&major, &minor, device_id);

    std::cout << "Mjaor: " << major << "," << "Minor: " << minor << std::endl;

    int max_smem_per_sm = 0;
    hipDeviceGetAttribute(
        &max_smem_per_sm, hipDeviceAttribute_t::hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, device_id);

    std::cout << "Max sems per sm : " << max_smem_per_sm << std::endl;

    using DTypeQ = __half;
    const int num_ctas_per_sm = max_smem_per_sm > (16 * 64 * sizeof(DTypeQ) * 16) ? 2 : 1;
    const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;

    std::cout << "Num ctas per sm : " << num_ctas_per_sm << std::endl;
    std::cout << "Max sems per block : " << max_smem_per_threadblock << std::endl;

    const int num_warps_z = 4;
    const uint32_t max_num_frags_z_smem =
        (max_smem_per_threadblock / (16 * 64 * sizeof(DTypeQ)) ) /
        (2 * num_warps_z);

    std::cout << "max_num_frags_z_smem : " << max_num_frags_z_smem << std::endl;

    // tests mfma

    size_t num_ele = 16 * 16;
    float* output = (float*)malloc(sizeof(float) * num_ele);

    float* output_d = nullptr;
    hipMalloc(&output_d, sizeof(float) * num_ele);

    // 1 x warp (block) test
    // v_mfma_test<<<1, 64>>>(output_d, M/*lda*/, N/*ldb*/); // pass with correct answer
    // v_mfma_intrinsics_test<<<1, 64>>>(output_d, M/*lda*/, N/*ldb*/); // pass with correct answer

    v_mfma_asm_test<<<1, 64>>>(output_d,  M/*lda*/, N/*ldb*/); // runs but fails numerically

    hipMemcpy(output, output_d, sizeof(float) * num_ele, hipMemcpyDeviceToHost);

    for (int i = 0; i < num_ele; i++) {
        std::cout << (float)output[i] << " ";
    }
    std::cout << std::endl;

    return 0;
}
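
The store mapping used in the kernels above (outIdx = y + j * 16 + x * 64 for a single wave at the origin) can be sanity-checked on the host. The following is a minimal sketch, added for illustration only and not part of the original test code; it should report that all 256 output slots are hit exactly once:

#include <cstdio>

// Host-side sanity check of the store mapping used in the kernels above:
// x = (lane / 16) % 4, y = lane % 16, outIdx = y + j * 16 + x * 64.
// For one 64-lane wave with 4 elements per lane it should cover each of the
// 256 output slots exactly once.
int main() {
    int hits[256] = {0};
    for (int lane = 0; lane < 64; ++lane) {
        int x = (lane / 16) % 4;
        int y = lane % 16;
        for (int j = 0; j < 4; ++j) {
            ++hits[y + j * 16 + x * 64];
        }
    }
    bool ok = true;
    for (int i = 0; i < 256; ++i) ok = ok && (hits[i] == 1);
    std::printf("store mapping covers all 256 outputs exactly once: %s\n",
                ok ? "yes" : "no");
    return 0;
}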

Here is the v_mfma_f32_16x16x16_f16 instruction generated by the kernel v_mfma_test when I inspect the *.S assembly file:

// with kernel 
    v_mov_b32_e32 v2, 0x2c002c00
    v_mov_b32_e32 v3, v2
    v_mov_b32_e32 v4, 0x3c003c00
    s_load_dwordx2 s[0:1], s[0:1], 0x0
    v_mov_b32_e32 v5, v4
    v_and_b32_e32 v1, 15, v0
    v_lshlrev_b32_e32 v0, 2, v0
    v_mfma_f32_16x16x16_f16 v[2:5], v[2:3], v[4:5], 0
    s_movk_i32 s2, 0xc0
    v_and_or_b32 v0, v0, s2, v1
    v_lshlrev_b32_e32 v0, 2, v0
    s_waitcnt lgkmcnt(0)
    s_nop 2
    global_store_dword v0, v2, s[0:1]
    global_store_dword v0, v3, s[0:1] offset:64
    global_store_dword v0, v4, s[0:1] offset:128
    global_store_dword v0, v5, s[0:1] offset:192
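
As a quick cross-check (this snippet is added for illustration and assumes the host compiler supports _Float16, as clang does), the packed immediates in the two v_mov_b32 instructions correspond to the fill values used in v_mfma_test: 1/16 for a_frag and 1.0 for b_frag.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Decode the packed fp16 immediates seen in the generated assembly:
//   v_mov_b32_e32 v2, 0x2c002c00   -> two fp16 values of 1/16 (a_frag fill)
//   v_mov_b32_e32 v4, 0x3c003c00   -> two fp16 values of 1.0  (b_frag fill)
static uint16_t fp16_bits(_Float16 x) {
    uint16_t b;
    std::memcpy(&b, &x, sizeof(b));
    return b;
}

int main() {
    std::printf("1/16 -> 0x%04x\n", (unsigned)fp16_bits((_Float16)(1.0f / 16.0f))); // expect 0x2c00
    std::printf("1.0  -> 0x%04x\n", (unsigned)fp16_bits((_Float16)1.0f));           // expect 0x3c00
    return 0;
}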

yiakwy-xpu-ml-framework-team commented 13 hours ago

The instruction now works as expected after insertion of proper fences:

__device__ void __my_sync_warp(void) { __asm__ volatile("s_barrier" ::); }
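// For reference only (an assumption about the toolchain, not part of the original fix):
// the same barrier can also be emitted with the existing clang/hipcc builtin instead
// of inline assembly, e.g.
// __device__ void __my_sync_warp(void) { __builtin_amdgcn_s_barrier(); }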

// a new single warp test for mfma instruction
__global__ void v_mfma_asm_test(float* output, unsigned int lda, unsigned int ldb) {
    int lane_id = threadIdx.x % rocwmma::Constants::AMDGCN_WAVE_SIZE ;

    // coordinates : see ROCM SDK MappingUtils 
    uint global_row = (blockIdx.x * blockDim.x + threadIdx.x) / 64 * M;
    uint global_col = (blockIdx.y * blockDim.y + threadIdx.y)      * N;

    using float16x4 = __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t;
    uint a_frag[2] = {0};
    uint b_frag[2] = {0};

    float16x4 *a = reinterpret_cast<float16x4 *>(a_frag);
    float16x4 *b = reinterpret_cast<float16x4 *>(b_frag);

    for (int j=0; j < 4; j++) {
        (*a)[j] = (float16_t)1.f;
        (*b)[j] = (float16_t)1.f;
    }

    __my_sync_warp(); //__syncthreads();

    // the output fragment is stored in 4 x AccVGPRs (see CDNA3 ISA)
    using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
    uint acc_frag[4] = {0};

    floatx4 *d = reinterpret_cast<floatx4 *>(acc_frag);

    unsigned int ele_per_thread = M * N / rocwmma::Constants::AMDGCN_WAVE_SIZE ;

    for (int i=0; i < K; i+= K) {

        // uint a_offset = i + global_row * lda;
        // uint b_offset = i + global_col * ldb;

        // rocwmma::load_matrix_sync(a_frag, a + a_offset, M);
        // rocwmma::load_matrix_sync(b_frag, b + b_offset, K);

        // 4 x outer product accumulation
        asm volatile("v_mfma_f32_16x16x16_f16 "
                        "%0, "
                        "%1, "
                        "%2, 0;\n"
                        : "=v"(*d)
                        :  "v"(*a), 
                           "v"(*b));

    }

    __my_sync_warp(); //__syncthreads();

    // store back to global memory
    for (int j=0; j < ele_per_thread; j++) {
        auto x = ( threadIdx.x / 16 ) % 4;
        auto y = threadIdx.x % 16; // y + j * 16 is the output data lane ID
        auto outIdx = (global_row * N + global_col) + y + j * 16 + x * 64;
        output[outIdx] = (*d)[j];
    }
}

This should output:

[Screenshot 2024-09-29 20:10:53: kernel output]