Bruce-Lee-LY / cuda_hgemm

Several optimization methods for half-precision general matrix multiplication (HGEMM) using Tensor Cores via the WMMA API and MMA PTX instructions.
MIT License
290 stars 66 forks source link

Change to block of 128 by 256 #3

Closed yupei-ms closed 1 year ago

yupei-ms commented 1 year ago

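谢谢分享代码!如果我把wmma_async_pg2s.cu 的block_rows和block_cols改成128 和 256,会出现error。我看不出来有什么问题... (Thanks for sharing the code! If I change block_rows and block_cols in wmma_async_pg2s.cu to 128 and 256, I get an error, and I can't see what's wrong...)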

./hgemm -M=4096 -N=4096 -K=1024 -profiling_iterations=1 -warmup_iterations=1 -enable_check=true
[HGEMM 2023-09-25 19:00:25 1022624:1022624 tester.h:72 evaluate] ----------------- Evaluating Wmma-Async-Pg2s -----------------
[HGEMM 2023-09-25 19:00:26 1022624:1022624 wmma_async_pg2s.cu:274 initWmmaAsyncPg2s] shmem_max_size: 66 KBytes (67584 Bytes)
[HGEMM 2023-09-25 19:00:30 1022624:1022624 cuda_timer.h:39 end] CUDA Runtime API error = 0700 "cudaErrorIllegalAddress", runtime version: 12000, driver version: 12020
// Tile configuration for the WMMA async-pg2s HGEMM kernel (128 x 256 block tile).
//
// Only the base sizes (BLOCK_*, WARP_*, WARP_SIZE, THREAD_COPY_BYTES,
// SHMEM_PADDING) are set by hand; every other constant is derived from them
// so the values cannot drift out of sync when a base size is edited.
// WMMA_M / WMMA_N / WMMA_K are not defined here and must be defined before
// any of the derived macros is expanded.
//
// NOTE(review): per the maintainer, the block tile size also determines how the
// global->shared (g2s) and shared->global (s2g) copy work is split across warps,
// so changing BLOCK_ROWS / BLOCK_COLS also requires reworking those copy loops in
// the kernel -- consistent macros alone are not enough.
#define BLOCK_ROWS 128
#define BLOCK_COLS 256

#define WARP_ROWS 64
#define WARP_COLS 64

#if (BLOCK_ROWS % WARP_ROWS) != 0 || (BLOCK_COLS % WARP_COLS) != 0
#error "BLOCK_ROWS and BLOCK_COLS must be multiples of WARP_ROWS and WARP_COLS"
#endif

// Warp grid inside one block tile (warp_id / BLOCK_ROW_WARPS picks the warp row,
// warp_id % BLOCK_ROW_WARPS picks the warp column).
#define BLOCK_ROW_WARPS (BLOCK_COLS / WARP_COLS)  // 4
#define BLOCK_COL_WARPS (BLOCK_ROWS / WARP_ROWS)  // 2

// WMMA fragment grid inside one block tile.
#define BLOCK_ROW_TILES (BLOCK_COLS / WMMA_N)  // 16 when WMMA_N == 16
#define BLOCK_COL_TILES (BLOCK_ROWS / WMMA_M)  // 8 when WMMA_M == 16

// WMMA fragment grid inside one warp tile.
#define WARP_ROW_TILES (WARP_COLS / WMMA_N)  // 4 when WMMA_N == 16
#define WARP_COL_TILES (WARP_ROWS / WMMA_M)  // 4 when WMMA_M == 16

#define WARP_SIZE 32
#define WARPS_PER_BLOCK (BLOCK_ROW_WARPS * BLOCK_COL_WARPS)  // 8
#define THREADS_PER_BLOCK (WARP_SIZE * WARPS_PER_BLOCK)      // 256

// Number of WMMA_K slices in one K step; a K step is 32 elements wide.
#define CHUNK_K (32 / WMMA_K)  // 2 when WMMA_K == 16

// Bytes each thread copies at a time (presumably one 16-byte vector/async copy -- confirm in kernel).
#define THREAD_COPY_BYTES 16

// One K-step line of A or B holds CHUNK_K * WMMA_K halves; the factor 2 is sizeof(half).
#define CHUNK_LINE_BYTES (CHUNK_K * WMMA_K * 2)                                       // 64
#define CHUNK_COPY_LINES_PER_WARP (WARP_SIZE * THREAD_COPY_BYTES / CHUNK_LINE_BYTES)  // 8
#define CHUNK_COPY_LINE_LANES (WARP_SIZE / CHUNK_COPY_LINES_PER_WARP)                 // 4

// Extra halves appended to every shared-memory row (presumably to reduce bank conflicts -- confirm).
#define SHMEM_PADDING 8

// Row stride, in halves, of the A/B shared-memory tiles.
#define AB_SHMEM_STRIDE (CHUNK_K * WMMA_K + SHMEM_PADDING)  // 40

// Row stride, in halves, of the C shared-memory tile, and the column offset
// (in halves) between horizontally adjacent warp tiles of C.
#define C_SHMEM_STRIDE (BLOCK_COLS + SHMEM_PADDING)  // 264
#define C_SHMEM_OFFSET (WARP_COLS)                   // 64
yupei-ms commented 1 year ago

compute-sanitizer显示问题在write back to gmem (compute-sanitizer shows that the fault is in the write-back to gmem.)

Bruce-Lee-LY commented 1 year ago

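不能用这种改法,block分块尺寸影响g2s和s2g的任务分配,有关代码都需要修改 (This kind of change won't work: the block tile size determines how the g2s (global-to-shared) and s2g (shared-to-global) copy work is distributed, so all of the related code has to be modified as well.)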

yupei-ms commented 1 year ago

Even with this 256x128 configuration, I think there is a bug in the correctness validation, and another bug in the write-back of results from smem to gmem.

diff --git a/src/common/matrix.h b/src/common/matrix.h
index 06381dc..5d6e6e8 100644
--- a/src/common/matrix.h
+++ b/src/common/matrix.h
@@ -75,8 +75,12 @@ public:
         HGEMM_CHECK_EQ(m_row, base->getRow());
         HGEMM_CHECK_EQ(m_col, base->getCol());

+       half m_host_ptr1[m_row*m_col];
+       for (size_t i = 0; i < m_elem_num; ++i) {
+            m_host_ptr1[i] = __float2half(0);
+        }
         HGEMM_CHECK_CUDART_ERROR(
-            cudaMemcpy(m_dev_ptr, base->getHostPtr(), m_elem_num * sizeof(half), cudaMemcpyHostToDevice));
+            cudaMemcpy(m_dev_ptr, m_host_ptr1, m_elem_num * sizeof(half), cudaMemcpyHostToDevice));
     }

     void moveToHost() {
@@ -93,6 +97,7 @@ public:
         double diff = 0.0;
         for (size_t i = 0; i < m_elem_num; ++i) {
             diff = static_cast<double>(std::abs(__half2float(m_host_ptr[i]) - __half2float(base->getHostPtr()[i])));
+            if(diff > 0.5) HLOG("idx %zu, C: %f, Base: %f", i, __half2float(m_host_ptr[i]), __half2float(base->getHostPtr()[i]));
             m_max_diff = std::max(m_max_diff, diff);
             m_avg_diff += diff;
         }
diff --git a/src/main.cu b/src/main.cu
index 16c9444..73b67ee 100644
--- a/src/main.cu
+++ b/src/main.cu
@@ -36,7 +36,7 @@ DEFINE_uint32(M, 512, "M");
 DEFINE_uint32(N, 2048, "N");
 DEFINE_uint32(K, 1024, "K");
 DEFINE_bool(enable_wmma, true, "test WMMA API");
-DEFINE_bool(enable_mma, true, "test MMA PTX instruction");
+DEFINE_bool(enable_mma, false, "test MMA PTX instruction");
 DEFINE_uint32(warmup_iterations, 1, "warmup iteration numbers and average the result");
 DEFINE_uint32(profiling_iterations, 10, "profiling iteration numbers and average the result");
 DEFINE_uint32(sleep_duration, 100, "sleep_milliseconds between profiling");
@@ -103,12 +103,12 @@ int main(int argc, char *argv[]) {
     if (FLAGS_enable_wmma) {
         // tester.evaluate(wmmaNaive, "Wmma-Naive");
         // tester.evaluate(wmmaBase, "Wmma-Base");
-        tester.evaluate(wmmaPadding, "Wmma-Padding");
-        tester.evaluate(wmmaAsync, "Wmma-Async");
+        //tester.evaluate(wmmaPadding, "Wmma-Padding");
+        //tester.evaluate(wmmaAsync, "Wmma-Async");
         tester.evaluate(wmmaAsyncPg2s, "Wmma-Async-Pg2s");
-        tester.evaluate(wmmaAsyncPg2sPs2r, "Wmma-Async-Pg2s-Ps2r");
-        tester.evaluate(wmmaAsyncStage2, "Wmma-Async-Stage2");
-        tester.evaluate(wmmaAsyncStage3, "Wmma-Async-Stage3");
+        //tester.evaluate(wmmaAsyncPg2sPs2r, "Wmma-Async-Pg2s-Ps2r");
+        //tester.evaluate(wmmaAsyncStage2, "Wmma-Async-Stage2");
+        //tester.evaluate(wmmaAsyncStage3, "Wmma-Async-Stage3");
     }

     if (FLAGS_enable_mma) {
diff --git a/src/wmma/wmma_async_pg2s.cu b/src/wmma/wmma_async_pg2s.cu
index e860f22..a14f4c7 100644
--- a/src/wmma/wmma_async_pg2s.cu
+++ b/src/wmma/wmma_async_pg2s.cu
@@ -73,9 +73,9 @@ __global__ void wmmaAsyncPg2sKernel(const half *__restrict__ A, const half *__re
     half *smem_warp_tile_ptr = &smem[0][0] + (warp_id / BLOCK_ROW_WARPS) * C_SMEM_STRIDE * WARP_ROWS +
                                (warp_id % BLOCK_ROW_WARPS) * C_SMEM_OFFSET;

-    half *smem_warp_stream_ptr = &smem[0][0] + warp_id * WMMA_M * C_SMEM_STRIDE;
+    half *smem_warp_stream_ptr = &smem[0][0] + warp_id *2* WMMA_M * C_SMEM_STRIDE;

-    const size_t gmem_idx = (block_tile_i + warp_id) * WMMA_M * N + block_tile_j * WMMA_N;
+    const size_t gmem_idx = (block_tile_i + warp_id*2) * WMMA_M * N + block_tile_j * WMMA_N;
     half *src_gmem_warp_stream_ptr = &C[gmem_idx];

     wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> C_frag[WARP_COL_TILES][WARP_ROW_TILES];