kyegomez / BitNet

Implementation of "BitNet: Scaling 1-bit Transformers for Large Language Models" in PyTorch.
https://discord.gg/qUtxnK2NMf
MIT License
1.69k stars · 155 forks · source link

CUDA Optimization #67

Open · simulanics opened 4 weeks ago

simulanics commented 4 weeks ago

Here are a few suggested CUDA upgrades: a shared-memory, tiled version of the low-precision GEMM kernel, plus its PyTorch wrapper:

#include <torch/extension.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_fp16.h>

#include <cstdio>

// Edge length of the square shared-memory tiles used by the GEMM kernel.
// Each thread block is launched as TILE_SIZE x TILE_SIZE threads.
// Can be overridden at compile time (e.g. -DTILE_SIZE=32), but
// TILE_SIZE * TILE_SIZE must stay within the 1024-threads-per-block limit.
#ifndef TILE_SIZE
#define TILE_SIZE 16
#endif
static_assert(TILE_SIZE > 0 && TILE_SIZE * TILE_SIZE <= 1024,
              "TILE_SIZE * TILE_SIZE must be a valid thread-block size (<= 1024)");

// Tiled half-precision GEMM: C = A * B.
//   A: M x K, B: K x N, C: M x N, all dense row-major `half` matrices.
//
// Launch layout (required):
//   blockDim = (TILE_SIZE, TILE_SIZE)
//   gridDim  = (ceil(N / TILE_SIZE), ceil(M / TILE_SIZE))
// Each thread computes one element of C.
// Static shared memory: 2 * TILE_SIZE * TILE_SIZE halves.
//
// Precision: products are formed and accumulated in float, and only the final
// sum is rounded to half. Rounding every product to half (as __hmul does) would
// lose precision and overflow to inf once |a * b| > 65504. Using float also
// avoids __hmul's SM53+ requirement.
__global__ void __launch_bounds__(TILE_SIZE * TILE_SIZE)
gemm_lowbit_kernel(const half* __restrict__ A,
                   const half* __restrict__ B,
                   half* __restrict__ C,
                   int M, int N, int K) {
    // Per-block tiles of A and B staged in shared memory.
    __shared__ half As[TILE_SIZE][TILE_SIZE];
    __shared__ half Bs[TILE_SIZE][TILE_SIZE];

    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int row = blockIdx.y * TILE_SIZE + ty;  // row of C computed by this thread
    const int col = blockIdx.x * TILE_SIZE + tx;  // column of C computed by this thread

    float sum = 0.0f;

    const int num_tiles = (K + TILE_SIZE - 1) / TILE_SIZE;
    for (int t = 0; t < num_tiles; ++t) {
        const int a_col = t * TILE_SIZE + tx;  // column of A this thread loads
        const int b_row = t * TILE_SIZE + ty;  // row of B this thread loads

        // Zero-pad out-of-range elements so partial edge tiles contribute nothing.
        // 64-bit offsets prevent int overflow when M*K or K*N exceeds INT_MAX.
        As[ty][tx] = (row < M && a_col < K)
            ? A[static_cast<size_t>(row) * K + a_col]
            : __float2half(0.0f);
        Bs[ty][tx] = (b_row < K && col < N)
            ? B[static_cast<size_t>(b_row) * N + col]
            : __float2half(0.0f);

        // The bounds checks above only select values, so every thread in the
        // block reaches this barrier.
        __syncthreads();

        #pragma unroll
        for (int k = 0; k < TILE_SIZE; ++k) {
            sum = fmaf(__half2float(As[ty][k]), __half2float(Bs[k][tx]), sum);
        }

        // Do not overwrite the tiles while other threads may still be reading them.
        __syncthreads();
    }

    if (row < M && col < N) {
        C[static_cast<size_t>(row) * N + col] = __float2half(sum);
    }
}

// Host wrapper that launches gemm_lowbit_kernel to compute c = a * b in FP16.
//
// Arguments:
//   a: CUDA float16 tensor holding at least M*K elements (row-major M x K).
//   b: CUDA float16 tensor holding at least K*N elements (row-major K x N).
//   c: CUDA float16 output tensor holding at least M*N elements
//      (row-major M x N); written in place.
// All three tensors must be on the same device. The kernel runs on that
// device's current PyTorch CUDA stream.
//
// Throws c10::Error (RuntimeError in Python) on invalid arguments or if the
// kernel launch fails, instead of returning an unwritten `c`.
void gemm_lowbit_cuda(at::Tensor a, at::Tensor b, at::Tensor c, int M, int N, int K) {
    TORCH_CHECK(M >= 0 && N >= 0 && K >= 0,
                "gemm_lowbit_cuda: M, N and K must be non-negative");
    TORCH_CHECK(a.is_cuda() && b.is_cuda() && c.is_cuda(),
                "gemm_lowbit_cuda: a, b and c must be CUDA tensors");
    TORCH_CHECK(a.device() == b.device() && a.device() == c.device(),
                "gemm_lowbit_cuda: a, b and c must be on the same device");
    TORCH_CHECK(a.scalar_type() == at::kHalf && b.scalar_type() == at::kHalf &&
                    c.scalar_type() == at::kHalf,
                "gemm_lowbit_cuda: a, b and c must have dtype float16");
    // The kernel indexes flat row-major buffers, so undersized tensors would
    // lead to out-of-bounds device accesses.
    TORCH_CHECK(a.numel() >= static_cast<int64_t>(M) * K,
                "gemm_lowbit_cuda: a has fewer than M*K elements");
    TORCH_CHECK(b.numel() >= static_cast<int64_t>(K) * N,
                "gemm_lowbit_cuda: b has fewer than K*N elements");
    TORCH_CHECK(c.numel() >= static_cast<int64_t>(M) * N,
                "gemm_lowbit_cuda: c has fewer than M*N elements");

    // An empty output needs no work, and a zero-sized grid is an invalid launch.
    if (M == 0 || N == 0) {
        return;
    }

    // Launch on the tensors' device, regardless of the caller's current device.
    const c10::cuda::CUDAGuard device_guard(a.device());

    a = a.contiguous();
    b = b.contiguous();
    // contiguous() returns a *copy* when c is not contiguous. The kernel writes
    // into c_out, which is copied back below; otherwise the result would never
    // reach the caller's tensor.
    at::Tensor c_out = c.contiguous();

    // One thread per output element, in TILE_SIZE x TILE_SIZE blocks.
    dim3 threads(TILE_SIZE, TILE_SIZE);
    dim3 blocks((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);

    // Queried after the device guard, so this is the stream of a's device.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    gemm_lowbit_kernel<<<blocks, threads, 0, stream>>>(
        reinterpret_cast<const half*>(a.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(b.data_ptr<at::Half>()),
        reinterpret_cast<half*>(c_out.data_ptr<at::Half>()),
        M, N, K);

    // Surface launch-configuration errors (e.g. a grid too large) to the caller
    // instead of just printing them.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "CUDA kernel launch failed: ", cudaGetErrorString(err));

    if (!c_out.is_same(c)) {
        c.copy_(c_out);  // stream-ordered after the kernel on the current stream
    }
}

Upvote & Fund

Fund with Polar