halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org

CUDA Cooperative Fetching in Matmul using Halide #8434

Closed nullplay closed 2 weeks ago

nullplay commented 1 month ago

Hi, I'm trying to cooperatively fetch the tiles of the input matrices for matrix multiplication in CUDA.

Below is my hand-written CUDA kernel, which I want to reproduce using Halide:

template <int BLOCK_I, int BLOCK_J, int BLOCK_K>
__global__ void matmul_l1(
    int32_t size_i,
    int32_t size_j,
    int32_t size_k,
    float const *a,
    float const *b,
    float *c) {
    __shared__ float A[BLOCK_I][BLOCK_K];
    __shared__ float B[BLOCK_J][BLOCK_K];

    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    int j = blockIdx.y * blockDim.y + threadIdx.y;

    float sum = 0;
    for (int32_t ko = 0; ko < size_k/BLOCK_K; ko++) {
      if (tx < BLOCK_I && ty < BLOCK_K) {
        int k = ko * BLOCK_K + ty;
        A[tx][ty] = a[i * size_k + k];
      }

      if (ty < BLOCK_J && tx < BLOCK_K) {
        int k = ko * BLOCK_K + tx;
        B[ty][tx] = b[k * size_j + j];
      }
      __syncthreads();

      for (int32_t ki = 0; ki < BLOCK_K; ki++) {
        sum += A[tx][ki] * B[ty][ki];
      }
      __syncthreads();
    }

    c[i * size_j + j] = sum;
}

void launch_matmul_l1(
    int32_t size_i,
    int32_t size_j,
    int32_t size_k,
    float const *a,
    float const *b,
    float *c) {

    const int BLOCK_I = 16;
    const int BLOCK_J = 16;
    const int BLOCK_K = 8;
    dim3 num_blocks = dim3((size_i+BLOCK_I-1)/BLOCK_I,
                           (size_j+BLOCK_J-1)/BLOCK_J,
                           1);
    dim3 block_size = dim3(BLOCK_I,BLOCK_J,1);

    matmul_l1<BLOCK_I, BLOCK_J, BLOCK_K><<<num_blocks, block_size>>>(size_i, size_j, size_k, a, b, c);
    return;
}

And this is the Halide generator I wrote to replicate that CUDA implementation:

import halide as hl

@hl.generator(name="test")
class Test:
    A = hl.InputBuffer(hl.Float(32), 2)
    B = hl.InputBuffer(hl.Int(32), 2)
    out = hl.OutputBuffer(hl.Float(32), 2)

    def generate(g):
        A = g.A
        B = g.B
        out = g.out

        # Variable Definition
        i,j,io,jo,ii,ji = hl.vars("i j io jo ii ji")
        ko, ki = hl.RVar("ko"), hl.RVar("ki")

        # Algorithm
        prod = hl.Func("prod")
        r = hl.RDom([(0, 256)])
        k = r.x
        prod[i,j] = 0.0
        prod[i,j] += A[i,k] * B[k,j]
        out[i,j] = prod[i,j]

        (out
            .tile(i, j, io, jo, ii, ji, 16, 16)
            .gpu_blocks(io, jo)
            .gpu_threads(ii, ji)
        )
        (prod
            .compute_at(out, ii)
            .update(0)
            .split(k, ko, ki, 8)
            .reorder(i, j, ki, ko)
        )
        A.in_().compute_at(prod, ko)#.store_in(hl.MemoryType.GPUShared)
        B.in_().compute_at(prod, ko)

        out.print_loop_nest()

with hl.GeneratorContext(hl.Target("host-cuda")):
    gen = Test()
f = gen.compile_to_callable()

However, this schedule allocates the A and B wrappers per thread rather than in shared memory. Is there a way to cooperatively load A and B into shared memory during the ko loop?

I tried A.in_().compute_at(prod, ko).store_in(hl.MemoryType.GPUShared) and A.in_().compute_at(prod, ko).store_at(out, ji), but neither produced the desired code. The second case also resulted in a compiler error.

nullplay commented 1 month ago

This is the Halide stmt generated by the schedule above:

gpu_block<CUDA> (out.s0.j.jo.block_id_y, 0, t52) {
   gpu_block<CUDA> (out.s0.i.io.block_id_x, 0, t53) {
    gpu_thread<CUDA> (.thread_id_y, 0, 16) {
     gpu_thread<CUDA> (.thread_id_x, 0, 16) {
      let out.s0.j.ji.base.s = min(out.s0.j.jo.block_id_y*16, out.extent.1 + -16)
      allocate prod.0[float32 * 1]
      allocate B_im_global_wrapper$0.2[int32 * 8]
      allocate A_im_global_wrapper$0.1[float32 * 8]
      let out.s0.i.ii.base.s = min(out.s0.i.io.block_id_x*16, out.extent.0 + -16)
      produce prod {
       prod.0[0] = 0.000000f
       let t58 = (((out.s0.j.ji.base.s + t55) + .thread_id_y)*B.stride.1) - B.min.0
       let t57 = (out.s0.i.ii.base.s + t54) + .thread_id_x
       for (prod.s1.r12$x.ko, 0, 32) {
        produce A_im_global_wrapper$0 {
         let t59 = prod.s1.r12$x.ko*8
         for (A_im_global_wrapper$0.s0._1.rebased, 0, 8) {
          A_im_global_wrapper$0.1[A_im_global_wrapper$0.s0._1.rebased] = A[((A_im_global_wrapper$0.s0._1.rebased + t59)*A.stride.1) + t57]
         }
        }
        produce B_im_global_wrapper$0 {
         let t60 = (prod.s1.r12$x.ko*8) + t58
         for (B_im_global_wrapper$0.s0._0.rebased, 0, 8) {
          B_im_global_wrapper$0.2[B_im_global_wrapper$0.s0._0.rebased] = B[B_im_global_wrapper$0.s0._0.rebased + t60]
         }
        }
        consume B_im_global_wrapper$0 {
         consume A_im_global_wrapper$0 {
          for (prod.s1.r12$x.ki, 0, 8) {
           prod.0[0] = prod.0[0] + (A_im_global_wrapper$0.1[prod.s1.r12$x.ki]*float32(B_im_global_wrapper$0.2[prod.s1.r12$x.ki]))
          }
         }
        }
       }
       free B_im_global_wrapper$0.2
       free A_im_global_wrapper$0.1
      }
      consume prod {
       out[((((out.min.1 + out.s0.j.ji.base.s) + .thread_id_y)*out.stride.1) + (out.s0.i.ii.base.s - (out.min.1*out.stride.1))) + .thread_id_x] = prod.0[0]
      }
      free prod.0
     }
    }
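
To put a rough number on the per-thread allocations in the stmt above, here is a small back-of-the-envelope count of global-memory reads per block for one ko iteration. It assumes the 16x16 tile and the split factor of 8 from the schedule. The function names are illustrative only and not part of Halide, and the count ignores whatever the hardware caches absorb.

```python
# Tile sizes taken from the schedule above: 16x16 output tile, k split by 8.
BLOCK_I, BLOCK_J, BLOCK_K = 16, 16, 8


def per_thread_loads(block_i, block_j, block_k):
    """Reads issued when every thread fills its own private 8-element
    copies of the A and B wrappers, as in the stmt above."""
    threads = block_i * block_j
    return threads * (block_k + block_k)


def cooperative_loads(block_i, block_j, block_k):
    """Reads issued when the block fills one shared A tile and one
    shared B tile, each element fetched exactly once."""
    return block_i * block_k + block_j * block_k


private = per_thread_loads(BLOCK_I, BLOCK_J, BLOCK_K)
shared = cooperative_loads(BLOCK_I, BLOCK_J, BLOCK_K)
print(private, shared, private // shared)  # 4096 256 16
```

Under these assumptions the per-thread version issues 16 times as many reads as a cooperative fetch would, which is the redundancy the hand-written kernel avoids.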

mcourteaux commented 1 month ago

You need to schedule the loading of the data to be parallelized by the threads:

        A.in_().compute_at(prod, ko)
        B.in_().compute_at(prod, ko)

Should become something like:

        A.in_().compute_at(prod, ko).gpu_threads(hl._0)
        B.in_().compute_at(prod, ko).gpu_threads(hl._0)

Where _0, _1, _2, ... are special variables in Halide used to refer to the dimensions of an input buffer (as opposed to a Func, where you can name the dimensions with variables).

I haven't fully studied your code, so you might need _1 in some cases. Try it and check what the stmt looks like; that will show you how this directive interacts with the data loading.
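
For reference when reading the resulting stmt, here is a plain-Python emulation of the access pattern the hand-written CUDA kernel uses: one thread block, two shared tiles, each thread loading at most one element of each tile before the barrier. This is only a sketch of the target behaviour, not Halide code and not a claim about what any particular schedule generates. The function name `matmul_block_cooperative` and the test sizes are made up for illustration, and the matrices are nested lists indexed as a[i][k] and b[k][j].

```python
import random

BLOCK_I, BLOCK_J, BLOCK_K = 16, 16, 8


def matmul_block_cooperative(a, b, size_k, block_x, block_y):
    """Emulate one BLOCK_I x BLOCK_J thread block; return its output tile."""
    # One copy of each tile per block, shared by all threads.
    a_sh = [[0.0] * BLOCK_K for _ in range(BLOCK_I)]
    b_sh = [[0.0] * BLOCK_K for _ in range(BLOCK_J)]
    sums = [[0.0] * BLOCK_J for _ in range(BLOCK_I)]
    threads = [(tx, ty) for ty in range(BLOCK_J) for tx in range(BLOCK_I)]

    for ko in range(size_k // BLOCK_K):
        # Phase 1: cooperative load, guarded like the CUDA kernel.
        for tx, ty in threads:
            i = block_x * BLOCK_I + tx
            j = block_y * BLOCK_J + ty
            if ty < BLOCK_K:
                a_sh[tx][ty] = a[i][ko * BLOCK_K + ty]
            if tx < BLOCK_K:
                b_sh[ty][tx] = b[ko * BLOCK_K + tx][j]
        # The loop boundary here plays the role of __syncthreads().
        # Phase 2: every thread reads the shared tiles.
        for tx, ty in threads:
            for ki in range(BLOCK_K):
                sums[tx][ty] += a_sh[tx][ki] * b_sh[ty][ki]
    return sums


random.seed(0)
size_i, size_j, size_k = 32, 32, 16
# Integer-valued floats keep the comparison below exact.
a = [[float(random.randint(-3, 3)) for _ in range(size_k)] for _ in range(size_i)]
b = [[float(random.randint(-3, 3)) for _ in range(size_j)] for _ in range(size_k)]

tile = matmul_block_cooperative(a, b, size_k, block_x=1, block_y=0)
reference = [[sum(a[BLOCK_I + ti][k] * b[k][tj] for k in range(size_k))
              for tj in range(BLOCK_J)] for ti in range(BLOCK_I)]
print(tile == reference)
```

In a stmt that fetches cooperatively, the wrapper allocations would be expected to sit outside the gpu_thread loops, with the thread indices appearing in the wrapper's store index, rather than each thread running its own 8-iteration copy loop.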