// CUDA kernel for optimized low-precision GEMM operation using shared memory and tiling
__global__ void gemm_lowbit_kernel(
const half* __restrict__ A,
const half* __restrict__ B,
half* __restrict__ C,
int M, int N, int K)
{
// Shared memory for tiles of A and B
__shared__ half As[TILE_SIZE][TILE_SIZE];
__shared__ half Bs[TILE_SIZE][TILE_SIZE];
// Calculate row and column indices of C element to work on
int row = blockIdx.y * TILE_SIZE + threadIdx.y; // Row index of C to compute
int col = blockIdx.x * TILE_SIZE + threadIdx.x; // Column index of C to compute
// Initialize the accumulator to zero
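// Accumulating in fp32 (rather than fp16) limits rounding error across the K-dimension sum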
float sum = 0.0f;
// Loop over tiles
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; ++t) {
// Load elements of A and B into shared memory if within bounds
if (row < M && (t * TILE_SIZE + threadIdx.x) < K)
As[threadIdx.y][threadIdx.x] = A[row * K + t * TILE_SIZE + threadIdx.x];
else
As[threadIdx.y][threadIdx.x] = __float2half(0.0f);
if (col < N && (t * TILE_SIZE + threadIdx.y) < K)
Bs[threadIdx.y][threadIdx.x] = B[(t * TILE_SIZE + threadIdx.y) * N + col];
else
Bs[threadIdx.y][threadIdx.x] = __float2half(0.0f);
__syncthreads(); // Synchronize to ensure data is loaded
// Compute partial dot product for this tile
#pragma unroll
for (int k = 0; k < TILE_SIZE; ++k) {
half a_element = As[threadIdx.y][k];
half b_element = Bs[k][threadIdx.x];
sum += __half2float(__hmul(a_element, b_element));
}
__syncthreads(); // Synchronize before loading the next tile
}
// Write the result to the output matrix if within bounds
if (row < M && col < N)
C[row * N + col] = __float2half(sum);
}
// Wrapper function to call the CUDA kernel
void gemm_lowbit_cuda(at::Tensor a, at::Tensor b, at::Tensor c, int M, int N, int K) {
// Ensure that input tensors are contiguous and on the correct device
a = a.contiguous();
b = b.contiguous();
c = c.contiguous();
// Define block and grid dimensions
dim3 threads(TILE_SIZE, TILE_SIZE);
dim3 blocks((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
// Get the CUDA stream from PyTorch
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Launch the optimized kernel
gemm_lowbit_kernel<<<blocks, threads, 0, stream>>>(
reinterpret_cast<const half*>(a.data_ptr<at::Half>()),
reinterpret_cast<const half*>(b.data_ptr<at::Half>()),
reinterpret_cast<half*>(c.data_ptr<at::Half>()),
M, N, K);
// Check for kernel launch errors
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("CUDA kernel launch failed: %s\n", cudaGetErrorString(err));
}
}
```
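The wrapper's comment says the inputs should be on the correct device, but as written it only calls `.contiguous()`. Below is a minimal validation sketch that could sit at the top of `gemm_lowbit_cuda`; the helper name `check_gemm_inputs` and the specific checks are a suggestion of mine, not part of the snippet above.

```
// Hypothetical input validation for gemm_lowbit_cuda (helper name and checks
// are a suggestion, not part of the original code).
#include <torch/extension.h>

static void check_gemm_inputs(const at::Tensor& a, const at::Tensor& b,
                              const at::Tensor& c, int M, int N, int K) {
  TORCH_CHECK(a.is_cuda() && b.is_cuda() && c.is_cuda(),
              "gemm_lowbit: all tensors must be CUDA tensors");
  TORCH_CHECK(a.scalar_type() == at::kHalf && b.scalar_type() == at::kHalf &&
              c.scalar_type() == at::kHalf,
              "gemm_lowbit: all tensors must be fp16 (at::Half)");
  TORCH_CHECK(a.numel() == static_cast<int64_t>(M) * K &&
              b.numel() == static_cast<int64_t>(K) * N &&
              c.numel() == static_cast<int64_t>(M) * N,
              "gemm_lowbit: tensor sizes do not match M, N, K");
  // The wrapper writes through c.contiguous(); if c were non-contiguous the
  // kernel would fill a temporary copy, so require contiguity up front.
  TORCH_CHECK(c.is_contiguous(),
              "gemm_lowbit: output tensor c must be contiguous");
}
```

TORCH_CHECK raises a catchable error on the Python side instead of silently launching the kernel on bad inputs.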
Just a few suggested CUDA upgrades:
```
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h> // half type and the __hmul / __float2half intrinsics

// Define TILE_SIZE for shared memory tiling
#define TILE_SIZE 16
```
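Since `torch/extension.h` is pulled in above, the wrapper is presumably meant to be built as a PyTorch C++/CUDA extension. A rough binding sketch follows; the exported name `gemm_lowbit` is an assumption, and `TORCH_EXTENSION_NAME` is supplied by PyTorch's build helpers (e.g. `torch.utils.cpp_extension.load`).

```
// Hypothetical pybind11 binding; the exported function name is a placeholder.
#include <torch/extension.h>

// Declared in the .cu file containing the kernel and wrapper above.
void gemm_lowbit_cuda(at::Tensor a, at::Tensor b, at::Tensor c, int M, int N, int K);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("gemm_lowbit", &gemm_lowbit_cuda,
        "Low-precision GEMM: fp16 inputs/output with fp32 accumulation");
}
```

From Python one would then call `module.gemm_lowbit(a, b, c, M, N, K)` with contiguous half tensors already allocated on the GPU.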