Open yiakwy-xpu-ml-framework-team opened 3 days ago
Hello @yiakwy-xpu-ml-framework-team, thanks for reaching out! I've created an internal ticket for investigation. Will report back with updates as they become available.
@cgmillette I am not sure if you are the right person for this question, but I hope you can give me some feedback on these three kernels if you are interested, since I found no direct assembly usage for our wmma instruction.
Luckily I just decoded the layout of wmma on AMD chips for the CDNA3 arch. But there is still a minor problem with direct use of the instruction.
I proposed three functions
Memory transaction estimation for this instruction:
// total loads : 16 x 16 x 2 = 512 bytes
// a memory transaction : 4 (32 bits) x 64 = 256 bytes
// phases : 2 = 512 / 256 (i.e. a 64-thread wave needs at least two passes to load all the data from smem into registers (VGPRs))
// elements per thread : 16 x 16 / 64 = 4
We can apply the same analysis to the other wmma instruction.
As for memory layout, I reproduced the single-warp result in the scripts below with a full explanation.
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <iostream>
using half = __half;
using float16_t = _Float16;
#include <rocwmma/rocwmma.hpp>
#define M 16
#define N 16
#define K 16
// Single-wave demo of the raw __builtin_amdgcn_mfma_f32_16x16x16f16 intrinsic:
// D (16x16 fp32) = A (16x16 fp16) x B (16x16 fp16) + D for one wavefront.
// Expected launch shape: <<<1, 64>>> (one 64-lane wave per 16x16 tile).
//
// NOTE(review): a_frag / b_frag are zero-initialized below and the global
// loads are commented out, so the MFMA accumulates zeros — this kernel mainly
// demonstrates the fragment register layout and the store index mapping.
__global__ void v_mfma_intrinsics_test(float* output, unsigned int lda, unsigned int ldb) {
// lane within the 64-wide wave (currently unused; kept for the layout notes below)
int lane_id = threadIdx.x % rocwmma::Constants::AMDGCN_WAVE_SIZE ;
// coordinates : see ROCM SDK MappingUtils
// one 16x16 tile per wave: waves advance along rows, blockIdx.y/threadIdx.y along columns
uint global_row = (blockIdx.x * blockDim.x + threadIdx.x) / 64 * M;
uint global_col = (blockIdx.y * blockDim.y + threadIdx.y) * N;
// in this test lda == BLOCK_SIZE_M (M), ldb == BLOCK_SIZE_N (N)
// NOTE(yiakwy) : I produce frag layout based on research https://github.com/yiakwy-xpu-ml-framework-team/AMD-lab-notes-fork/blob/release/matrix-cores/src/mfma_fp32_16x16x16fp16.cpp
// a_frag is stored in 2 x VGPRs pair (32 lanes x 2, see CDNA3 ISA), each thread processes 4 fp16 elements;
// lane 0-15 16-31 32-47 48-63
// Reg\Col 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// a_frag[0] x x x x
// a_frag[1] * * * *
// ...
// a_frag[3] + + +
//
// Here, a_frag[0] is accessed by 16 consecutive threads (T0-T15), each thread covers 4 elements
//
// lane
// row\Reg b_frag[0] b_frag[1] b_frag[2] b_frag[3]
// 0 x 0-15
// 1 *
// 2 &
// 3 +
// 0+4 x 16-31
// 1+4 *
// 2+4 &
// 3+4 +
// 0+8 x 32-47
// 1+8 *
// 2+8 &
// 3+8 +
// 0+12 x 48-63
// 1+12 *
// 2+12 *
// 3+12 +
// 4 x fp16 lanes viewed through a clang vector-extension type
using float16x4 = __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t;
uint a_frag[2] = {0}; // 2 x 32-bit words == 4 packed fp16 values, all zero
uint b_frag[2] = {0};
// reinterpret the raw VGPR words as fp16 vectors for the intrinsic's operand types
float16x4 *a = reinterpret_cast<float16x4 *>(a_frag);
float16x4 *b = reinterpret_cast<float16x4 *>(b_frag);
// the output fragment is stored in 4 x AccVGPRs (see CDNA3 ISA), each thread processes 4 fp32 elements
using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
uint acc_frag[4] = {0};
floatx4 *d = reinterpret_cast<floatx4 *>(acc_frag);
// 16*16 / 64 = 4 accumulator elements owned by each lane
unsigned int ele_per_thread = M * N / rocwmma::Constants::AMDGCN_WAVE_SIZE ;
// K-step loop: with stride K it executes exactly once (a single 16-wide K slice)
for (int i=0; i < K; i+= K) {
// uint a_warp_offset = i + global_row * lda;
// uint b_warp_offset = i + global_col * ldb;
// rocwmma::load_matrix_sync(a_frag, a + a_offset, M);
// rocwmma::load_matrix_sync(b_frag, b + b_offset, K);
// single MFMA; trailing 0,0,0 are the cbsz/abid/blgp modifiers (no broadcast/swizzle)
*d = __builtin_amdgcn_mfma_f32_16x16x16f16(*a, *b, *d, 0, 0, 0);
}
// Note(yiakwy) : store thread private memory back to global memory is tricky. Suppose output is still major layout , then threads configuration is (4/*x*/, 4/*j*/, 16/*y*/) :
//
// lane 0 1 15 16 17 32 ... 63
// Row(x) \ Col (j,y) <0,0> <0,1> ... <0,15> <1, 0> <1, 1> ... <1,15> ... <4,15>
// 0 x
// 1 x
// 2 x
// 3 x
//
// x = output[warp_offset + coord2Indx (x, j, y)], coord2Index : (x, j, y) -> ( y + j * 16 ) + x * 64
for (int j=0; j < ele_per_thread; j++) {
auto x = ( threadIdx.x / 16 ) % 4; // which 4-row group this lane owns
auto y = threadIdx.x % 16; // y + j * 16 is the output data lane ID
auto outIdx = (global_row * N + global_col) + y + j * 16 + x * 64; // row-major scatter of the 4 acc lanes
output[outIdx] = (*d)[j];
}
}
// a new single warp test for mfma instruction
// Single-wave smoke test issuing v_mfma_f32_16x16x16_f16 via inline assembly.
// A and B fragments are filled with 1.0, so every output element should be the
// dot product of 16 ones, i.e. 16.0f. Expected launch shape: <<<1, 64>>>.
// lda/ldb are unused here (the global loads are commented out); kept so the
// signature matches the sibling kernels.
__global__ void v_mfma_asm_test(float* output, unsigned int lda, unsigned int ldb) {
    int lane_id = threadIdx.x % rocwmma::Constants::AMDGCN_WAVE_SIZE ; // unused; kept for symmetry
    // coordinates : see ROCM SDK MappingUtils
    uint global_row = (blockIdx.x * blockDim.x + threadIdx.x) / 64 * M;
    uint global_col = (blockIdx.y * blockDim.y + threadIdx.y) * N;
    using float16x4 = __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t;
    uint a_frag[2] = {0};
    uint b_frag[2] = {0};
    float16x4 *a = reinterpret_cast<float16x4 *>(a_frag);
    float16x4 *b = reinterpret_cast<float16x4 *>(b_frag);
    for (int j=0; j < 4; j++) {
        // BUGFIX: cast through float16_t (_Float16), not __half — a __half class
        // object has no implicit conversion into a _Float16 vector lane.
        (*a)[j] = (float16_t)1.f;
        (*b)[j] = (float16_t)1.f;
    }
    // the output fragment is stored in 4 x AccVGPRs (see CDNA3 ISA)
    using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
    uint acc_frag[4] = {0};
    floatx4 *d = reinterpret_cast<floatx4 *>(acc_frag);
    // 16*16 / 64 = 4 accumulator elements per lane (hoisted out of the loop)
    unsigned int ele_per_thread = M * N / rocwmma::Constants::AMDGCN_WAVE_SIZE ;
    for (int i=0; i < K; i+= K) {
        // uint a_offset = i + global_row * lda;
        // uint b_offset = i + global_col * ldb;
        // rocwmma::load_matrix_sync(a_frag, a + a_offset, M);
        // rocwmma::load_matrix_sync(b_frag, b + b_offset, K);
        // 4 x outer product accumulation; literal 0 is the accumulator source.
        // BUGFIX: MFMA results land in AccVGPRs with multi-cycle latency and the
        // compiler cannot insert wait states around an opaque asm block, so the
        // s_nops below make the destination readable before the store loop
        // (this is the numerical failure the bare instruction exhibited).
        asm volatile("v_mfma_f32_16x16x16_f16 "
                     "%0, "
                     "%1, "
                     "%2, 0\n\t"
                     "s_nop 7\n\t"
                     "s_nop 3\n\t"
                     : "=v"(*d)
                     : "v"(*a),
                       "v"(*b)
                     : "memory");
    }
    // BUGFIX: store back to global memory AFTER the accumulation loop — in the
    // original it was nested inside it (harmless only because the loop runs once).
    for (int j=0; j < ele_per_thread; j++) {
        auto x = ( threadIdx.x / 16 ) % 4; // which 4-row group this lane owns
        auto y = threadIdx.x % 16; // y + j * 16 is the output data lane ID
        auto outIdx = (global_row * N + global_col) + y + j * 16 + x * 64;
        output[outIdx] = (*d)[j];
    }
}
// Reference single-wave GEMM tile using the cooperative rocWMMA API:
// D (16x16 fp32) = A (16x16 fp16, row major) x B (16x16 fp16, col major) + 0.
// A is filled with 1/16 and B with 1, so each output element should be
// 16 * (1/16) * 1 = 1.0f. Expected launch shape: <<<1, 64>>> (one wavefront).
__global__ void v_mfma_test(float* output, unsigned int lda, unsigned int ldb)
{
// lane within the 64-wide wave (unused; kept for symmetry with the sibling kernels)
int lane_id = threadIdx.x % rocwmma::Constants::AMDGCN_WAVE_SIZE ;
// coordinates : see ROCM SDK MappingUtils
uint global_row = (blockIdx.x * blockDim.x + threadIdx.x) / 64 * M;
uint global_col = (blockIdx.y * blockDim.y + threadIdx.y) * N;
rocwmma::fragment<rocwmma::matrix_a, 16, 16, 16, half, rocwmma::row_major> a_frag;
rocwmma::fragment<rocwmma::matrix_b, 16, 16, 16, half, rocwmma::col_major> b_frag;
rocwmma::fragment<rocwmma::accumulator, 16, 16, 16, float> acc_frag; // ComputeT
rocwmma::fill_fragment(acc_frag, 0.0f);
// total loads : 16 x 16 x 2 = 512 bytes
// a memory transaction : 4 (32 bits) x 64 = 256 bytes
// phases : 2 = 512 / 256
// elements per thread : 16 x 16 / 64 = 4
// K-step loop: with stride K it executes exactly once
for (int i = 0; i < K; i+=K) {
rocwmma::fill_fragment(a_frag, (half)(1 / 16.0));
rocwmma::fill_fragment(b_frag, (half)1.0);
// Matrix multiply - accumulate using MFMA units
rocwmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
}
// output is row major; rocWMMA handles the lane->element scatter internally
rocwmma::store_matrix_sync(output + (global_row * N + global_col), acc_frag, M, rocwmma::mem_row_major);
}
int main(int argc, char * argv[]) {
int device_id = 0;
hipGetDevice(&device_id);
int major = 0, minor = 0;
hipDeviceComputeCapability(&major, &minor, device_id);
std::cout << "Mjaor: " << major << "," << "Minor: " << minor << std::endl;
int max_smem_per_sm = 0;
hipDeviceGetAttribute(
&max_smem_per_sm, hipDeviceAttribute_t::hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, device_id);
std::cout << "Max sems per sm : " << max_smem_per_sm << std::endl;
using DTypeQ = __half;
const int num_ctas_per_sm = max_smem_per_sm > (16 * 64 * sizeof(DTypeQ) * 16) ? 2 : 1;
const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;
std::cout << "Num ctas per sm : " << num_ctas_per_sm << std::endl;
std::cout << "Max sems per block : " << max_smem_per_threadblock << std::endl;
const int num_warps_z = 4;
const uint32_t max_num_frags_z_smem =
(max_smem_per_threadblock / (16 * 64 * sizeof(DTypeQ)) ) /
(2 * num_warps_z);
std::cout << "max_num_frags_z_smem : " << max_num_frags_z_smem << std::endl;
// tests mfma
size_t num_ele = 16 * 16;
float* output = (float*)malloc(sizeof(float) * num_ele);
float* output_d = nullptr;
hipMalloc(&output_d, sizeof(float) * num_ele);
// 1 x warp (block) test
// v_mfma_test<<<1, 64>>>(output_d, M/*lda*/, N/*ldb*/); // pass with correct answer
// v_mfma_intrinsics_test<<<1, 64>>>(output_d, M/*lda*/, N/*ldb*/); // pass with correct answer
v_mfma_asm_test<<<1, 64>>>(output_d, M/*lda*/, N/*ldb*/); // pass but has numericly failed
hipMemcpy(output, output_d, sizeof(float) * num_ele, hipMemcpyDeviceToHost);
for (int i = 0; i < num_ele; i++) {
std::cout << (float)output[i] << " ";
}
std::cout << std::endl;
return 0;
}
here is the v_mfma_f32_16x16x16_f16 instruction generated by the kernel v_mfma_test when I inspect the *.S assembly file:
// with kernel
v_mov_b32_e32 v2, 0x2c002c00
v_mov_b32_e32 v3, v2
v_mov_b32_e32 v4, 0x3c003c00
s_load_dwordx2 s[0:1], s[0:1], 0x0
v_mov_b32_e32 v5, v4
v_and_b32_e32 v1, 15, v0
v_lshlrev_b32_e32 v0, 2, v0
v_mfma_f32_16x16x16_f16 v[2:5], v[2:3], v[4:5], 0
s_movk_i32 s2, 0xc0
v_and_or_b32 v0, v0, s2, v1
v_lshlrev_b32_e32 v0, 2, v0
s_waitcnt lgkmcnt(0)
s_nop 2
global_store_dword v0, v2, s[0:1]
global_store_dword v0, v3, s[0:1] offset:64
global_store_dword v0, v4, s[0:1] offset:128
global_store_dword v0, v5, s[0:1] offset:192
The instruction now works as expected after insertion of proper fences:
// Workgroup-wide barrier emitted as a single s_barrier; with one 64-lane wave
// it serves as a scheduling fence around the MFMA asm block.
__device__ void __my_sync_warp(void) { __builtin_amdgcn_s_barrier(); }
// Corrected single-warp test for the v_mfma_f32_16x16x16_f16 instruction.
// A and B are filled with 1.0, so every output element should be 16.0f
// (a K=16 dot product of ones). Expected launch shape: <<<1, 64>>>.
// NOTE(review): this redefines v_mfma_asm_test — it replaces the earlier
// version in the post and cannot coexist with it in one translation unit.
__global__ void v_mfma_asm_test(float* output, unsigned int lda, unsigned int ldb) {
int lane_id = threadIdx.x % rocwmma::Constants::AMDGCN_WAVE_SIZE ; // unused; kept for symmetry
// coordinates : see ROCM SDK MappingUtils
uint global_row = (blockIdx.x * blockDim.x + threadIdx.x) / 64 * M;
uint global_col = (blockIdx.y * blockDim.y + threadIdx.y) * N;
using float16x4 = __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t;
uint a_frag[2] = {0};
uint b_frag[2] = {0};
float16x4 *a = reinterpret_cast<float16x4 *>(a_frag);
float16x4 *b = reinterpret_cast<float16x4 *>(b_frag);
for (int j=0; j < 4; j++) {
// fill via float16_t (_Float16), the element type of the vector extension
(*a)[j] = (float16_t)1.f;
(*b)[j] = (float16_t)1.f;
}
// Fence before the MFMA. NOTE(review): s_barrier keeps the compiler from
// reordering the fragment initialization past the opaque asm block —
// presumably it also forces the needed wait states; confirm against the
// generated ISA.
__my_sync_warp(); //__syncthreads();
// the output fragment is stored in 4 x AccVGPRs (see CDNA3 ISA)
using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
uint acc_frag[4] = {0};
floatx4 *d = reinterpret_cast<floatx4 *>(acc_frag);
// 16*16 / 64 = 4 accumulator elements per lane
unsigned int ele_per_thread = M * N / rocwmma::Constants::AMDGCN_WAVE_SIZE ;
// K-step loop: with stride K it executes exactly once
for (int i=0; i < K; i+= K) {
// uint a_offset = i + global_row * lda;
// uint b_offset = i + global_col * ldb;
// rocwmma::load_matrix_sync(a_frag, a + a_offset, M);
// rocwmma::load_matrix_sync(b_frag, b + b_offset, K);
// 4 x outer product accumulation; the literal 0 operand is the accumulator source
asm volatile("v_mfma_f32_16x16x16_f16 "
"%0, "
"%1, "
"%2, 0;\n"
: "=v"(*d)
: "v"(*a),
"v"(*b));
}
// Fence after the MFMA so the AccVGPR results are valid before the stores.
__my_sync_warp(); //__syncthreads();
// store back to global memory
for (int j=0; j < ele_per_thread; j++) {
auto x = ( threadIdx.x / 16 ) % 4; // which 4-row group this lane owns
auto y = threadIdx.x % 16; // y + j * 16 is the output data lane ID
auto outIdx = (global_row * N + global_col) + y + j * 16 + x * 64; // row-major scatter
output[outIdx] = (*d)[j];
}
}
This now produces the expected output (every element equals 16).
Problem Description
I am investigating usage of instruction v_mfma_f32_16x16x16_f16 and nvidia equivalent warp-level mma (swizzle SRAM memory + ldmatrix registers + mma over registers, for Ampere arch style computation) with multiple blocks.
And I found our GEMM tests fail in ROCm 6.2.
Operating System
Ubuntu 22.04
CPU
AMD EPYC 9534 64-Core Processor
GPU
AMD Instinct MI300X
ROCm Version
ROCm 6.2.0
ROCm Component
rocWMMA
Steps to Reproduce
build rocWMMA and trigger tests of gemm_PGR0_LB0_MP0_MB_NC_ad_hoc-validate (multiple blocks, no memory optimization)
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response