NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[QST] How to call cutlass API within a cuda kernel? #1164

Closed mathfirst closed 11 months ago

mathfirst commented 1 year ago

I would like to use CUTLASS to perform a matrix multiplication inside a CUDA kernel. Specifically, before the multiplication I need to do some custom work while loading the input matrices A (m x k) and B (k x n) into shared memory. I then compute C = AB (m x n), and once C is available I need to do further processing on it. The code looks something like this:

#define BLOCK_SIZE 32
template<typename scalar_t>
__global__ void involved_kernel_including_matmul(
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> A,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> B,
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> C
){
    // Shared memory used to store Asub and Bsub, respectively.
    // shared memory is shared within one thread block.
    __shared__ float Asub[BLOCK_SIZE][BLOCK_SIZE];  
    __shared__ float Bsub[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float Si[BLOCK_SIZE][BLOCK_SIZE];
    // Each thread calculates one entry of C by accumulating results into tmpSum
    int col = blockDim.x * blockIdx.x + threadIdx.x;
    int row = blockDim.y * blockIdx.y + threadIdx.y;
    unsigned int col_sub = threadIdx.x; // Thread column within Csub
    unsigned int row_sub = threadIdx.y; // Thread row within Csub
    int height_A = A.size(0), width_A = A.size(1), width_B = B.size(1);
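    // Do not return early here: every thread must reach the __syncthreads() calls
    // below and help load the shared tiles, so only the global reads are guarded.
    double tmpSum = 0.0;
    for (int m=0; m<width_A/BLOCK_SIZE; m++)
    {
        // I need to do some initialization before loading A and B onto shared memory.
        ...
        Asub[row_sub][col_sub] = (row < height_A) ? A[row][m*BLOCK_SIZE + col_sub] : scalar_t(0);
        Bsub[row_sub][col_sub] = (col < width_B) ? B[m*BLOCK_SIZE + row_sub][col] : scalar_t(0);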
        // Synchronize to ensure the sub-matrices are loaded before computation.
        __syncthreads();
        // The following matrix multiply is naive and slow. How can I replace it with the CUTLASS API to speed it up?
        for (int k=0; k<BLOCK_SIZE; k++)
        {
            tmpSum += Asub[row_sub][k] * Bsub[k][col_sub];
        }
        Si[row_sub][col_sub] = tmpSum;
        __syncthreads();
        // With Si, I need to do some computation later.
        ...
    }
}

My actual question is how to use CUTLASS within a CUDA kernel. I am new to CUTLASS and find it hard to use. Could anybody show me a minimal example for my case? Any guidance would be appreciated.

hwu36 commented 1 year ago

Which architecture are you using?

CUTLASS requires data in shared memory to be stored in a special swizzled layout to prevent bank conflicts, unlike what your pseudocode does. Maybe you can just use WMMA in your code?
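As a rough illustration of the swizzling idea (a simplified sketch, not CUTLASS's actual layout), the snippet below XORs the column index with the row index when addressing a 32x32 float tile. It assumes 32 shared-memory banks of 4 bytes each. The names `swizzled_col`, `transpose_tile_swizzled`, and `banks_touched_by_column` are hypothetical and do not come from CUTLASS.

```cuda
#include <cuda_runtime.h>

constexpr int TILE_DIM = 32;  // assumed: tile width equals the number of banks

// Hypothetical helper: physical column used to store logical element (row, col).
__host__ __device__ inline int swizzled_col(int row, int col)
{
    return col ^ (row % TILE_DIM);
}

// Transposes one 32x32 float tile through shared memory.
// Launch with a single block of dim3(32, 32) threads.
__global__ void transpose_tile_swizzled(const float* in, float* out)
{
    __shared__ float tile[TILE_DIM][TILE_DIM];
    const int x = threadIdx.x;
    const int y = threadIdx.y;

    // Row-wise store: a warp (fixed y) writes 32 distinct physical columns.
    tile[y][swizzled_col(y, x)] = in[y * TILE_DIM + x];
    __syncthreads();

    // Column-wise read: a warp (fixed y) reads logical column y across rows x.
    // Without the XOR, all 32 reads would land in the same bank.
    out[y * TILE_DIM + x] = tile[x][swizzled_col(x, y)];
}

// Host-side check: how many distinct banks a warp touches when it reads
// one logical column of the swizzled tile.
int banks_touched_by_column(int col)
{
    bool seen[TILE_DIM] = {};
    int count = 0;
    for (int row = 0; row < TILE_DIM; ++row) {
        const int bank = swizzled_col(row, col) % TILE_DIM;
        if (!seen[bank]) { seen[bank] = true; ++count; }
    }
    return count;
}
```

With this mapping, both the row-wise store and the column-wise read spread a warp's accesses over all 32 banks. The layouts CUTLASS uses for tensor-core loads are more involved, but they build on the same XOR idea.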

mathfirst commented 1 year ago

Thanks for your helpful reply. I am using an A100 GPU. Your suggestion sounds good and I'll try it. In terms of speedup, is CUTLASS much faster than using WMMA directly, or are they comparable?
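For reference, the WMMA route suggested above could look roughly like the sketch below. It is a minimal, unoptimized example under several assumptions: half-precision inputs with float accumulation, row-major A, B, and C, dimensions that are multiples of 16, one warp per block, and compilation for sm_70 or newer (for example `-arch=sm_80` on an A100). The names `wmma_tile_gemm` and `run_wmma_demo` are hypothetical. The comments mark where the custom pre- and post-processing from the question would go.

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <mma.h>

using namespace nvcuda;

constexpr int TILE = 16;  // WMMA fragment shape used here is 16x16x16

// One warp (32 threads) per block computes one 16x16 tile of C = A * B.
// A is M x K, B is K x N, C is M x N, all row-major.
__global__ void wmma_tile_gemm(const half* A, const half* B, float* C, int N, int K)
{
    // load/store_matrix_sync need 32-byte aligned pointers.
    __shared__ __align__(32) half Asub[TILE][TILE];
    __shared__ __align__(32) half Bsub[TILE][TILE];
    __shared__ __align__(32) float Csub[TILE][TILE];

    const int tile_row = blockIdx.y * TILE;
    const int tile_col = blockIdx.x * TILE;
    const int lane = threadIdx.x;

    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
    wmma::fill_fragment(c_frag, 0.0f);

    for (int k0 = 0; k0 < K; k0 += TILE) {
        // Custom initialization before loading the tiles would go here.
        for (int i = lane; i < TILE * TILE; i += 32) {
            const int r = i / TILE, c = i % TILE;
            Asub[r][c] = A[(tile_row + r) * K + (k0 + c)];
            Bsub[r][c] = B[(k0 + r) * N + (tile_col + c)];
        }
        __syncthreads();

        // Tensor-core multiply-accumulate on the shared-memory tiles.
        wmma::load_matrix_sync(a_frag, &Asub[0][0], TILE);
        wmma::load_matrix_sync(b_frag, &Bsub[0][0], TILE);
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
        __syncthreads();
    }

    // Spill the accumulator to shared memory so it can be indexed by row and column.
    wmma::store_matrix_sync(&Csub[0][0], c_frag, TILE, wmma::mem_row_major);
    __syncthreads();

    // Custom post-processing on Csub would go here.
    for (int i = lane; i < TILE * TILE; i += 32) {
        const int r = i / TILE, c = i % TILE;
        C[(tile_row + r) * N + (tile_col + c)] = Csub[r][c];
    }
}

// Host helper: multiplies all-ones matrices and returns the last entry of C.
float run_wmma_demo(int M, int N, int K)
{
    half *A, *B;
    float *C;
    cudaMallocManaged(&A, sizeof(half) * M * K);
    cudaMallocManaged(&B, sizeof(half) * K * N);
    cudaMallocManaged(&C, sizeof(float) * M * N);
    for (int i = 0; i < M * K; ++i) A[i] = __float2half(1.0f);
    for (int i = 0; i < K * N; ++i) B[i] = __float2half(1.0f);

    dim3 grid(N / TILE, M / TILE);
    wmma_tile_gemm<<<grid, 32>>>(A, B, C, N, K);
    cudaDeviceSynchronize();

    const float result = C[M * N - 1];
    cudaFree(A);
    cudaFree(B);
    cudaFree(C);
    return result;
}
```

With all-ones inputs every entry of C should equal K, which gives a quick sanity check. The shared tiles here use a plain row-major layout, so this sketch does nothing about bank conflicts and uses only one warp per block. Both are likely reasons a direct WMMA version falls short of CUTLASS.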

hwu36 commented 1 year ago

CUTLASS is faster but more complex. That is the tradeoff you need to make.

mnicely commented 12 months ago

@mathfirst, FYI, we plan to release cuBLASDx in the next few months for executing GEMMs in a CUDA kernel. You can check out cuFFTDx to get an idea of the API and intent. https://docs.nvidia.com/cuda/cufftdx/index.html

mathfirst commented 11 months ago

@mnicely That sounds great. I tried wmma but it is not as fast as expected. It may need some optimization. Looking forward to cuBLASDx! I will check out cuFFTDx as you suggested.

mnicely commented 11 months ago

Closing this for now. Keep an eye out for cuBLASDx, as I think it will be the best solution.

mnicely commented 8 months ago

cuBLASDx examples have been posted: https://github.com/NVIDIA/CUDALibrarySamples/tree/master/MathDx

mathfirst commented 8 months ago

Thanks for your hard work!