halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org
Other
5.91k stars 1.07k forks source link

Optimizing Shared Memory Usage for a Tensor Product in Halide on GPU #8427

Closed nullplay closed 2 weeks ago

nullplay commented 1 month ago

Hi,

I'm implementing a tensor product operation in Halide that involves gathering inputs and scattering the final output on a GPU. I'm aiming to optimize shared memory usage for better performance, but I'm encountering some challenges.

Here is a minimal reproducer of my Halide generator code:

import halide as hl

@hl.generator(name="test")
class Test:
    """Gather -> per-slice tensor product -> scatter-add, scheduled for CUDA.

    For every slice n of ``weight`` (dim 2) and position p in [0, maxPos):

        Product[m, p, n] = sum_c input[c, imap[p, n]] * weight[m, c, n]
        output[m, omap[p, n]] += Product[m, p, n]
    """

    maxPos = hl.InputScalar(hl.Int(32))   # number of positions p per slice n
    outSize = hl.InputScalar(hl.Int(32))  # assumed: row count of output (dim-1 extent) — confirm

    # Row-index maps, indexed [p, n]: imap picks the input row (input dim 1)
    # gathered for position p of slice n; omap picks the output row (output
    # dim 1) the result for (p, n) is scattered to.
    imap = hl.InputBuffer(hl.Int(32), 2)
    omap = hl.InputBuffer(hl.Int(32), 2)

    weight = hl.InputBuffer(hl.Float(32), 3)   # indexed [m, c, n]
    input = hl.InputBuffer(hl.Float(32), 2)    # indexed [c, row]; generator param name, kept as is
    output = hl.OutputBuffer(hl.Float(32), 2)  # indexed [m, row]

    def generate(g):
        imap = g.imap
        omap = g.omap
        weight = g.weight
        input_ = g.input  # trailing underscore: don't shadow the builtin `input`
        output = g.output
        maxPos = g.maxPos
        outSize = g.outSize

        # Fix the m (dim 0) and c (dim 1) extents of weight to 64 so the c
        # reduction below has a constant extent.
        weight.dim(1).set_bounds(0, 64)
        weight.dim(0).set_bounds(0, 64)

        # Pure variables, plus the RVar names produced by splitting reductions.
        o, n, c, p, m = hl.vars("o n c p m")
        m1, m0 = hl.vars("m1 m0")
        p1, p0 = hl.RVar("p1"), hl.RVar("p0")
        c1, c0 = hl.RVar("c1"), hl.RVar("c0")

        # ---------------------------------------------------------------- Algorithm
        # 1. Gather input rows. The gathered index addresses input dim 1 (rows),
        #    so the promised range must be dim 1's range -- not dim 0 (channels).
        #    unsafe_promise_clamped emits no runtime check, so imap values must
        #    really lie in [0, input.dim(1).max()].
        gather_input = hl.Func("GatherInput")
        gather_row = hl.unsafe_promise_clamped(imap[p, n], 0, input_.dim(1).max())
        gather_input[c, p, n] = input_[c, gather_row]

        # 2. Weight slice (identity copy), kept as its own Func so it can be
        #    scheduled separately (e.g. staged in GPU shared memory).
        gather_weight = hl.Func("GatherWeight")
        gather_weight[m, c, n] = weight[m, c, n]

        # 3. Tensor product: reduce over c (weight dim 1).
        r1 = hl.RDom([(0, weight.dim(1).extent())])
        product = hl.Func("Product")
        product[m, p, n] = 0.0
        product[m, p, n] += gather_input[r1.x, p, n] * gather_weight[m, r1.x, n]

        # 4. Scatter-add into output rows. The promised range is inclusive (like
        #    clamp), so the last valid row is outSize - 1.
        r2 = hl.RDom([(0, weight.dim(2).extent()), (0, maxPos)])
        output[m, o] = 0.0
        scatter_row = hl.unsafe_promise_clamped(omap[r2.y, r2.x], 0, outSize - 1)
        output[m, scatter_row] += product[m, r2.y, r2.x]

        # ----------------------------------------------------------------- Schedule
        # Pure init: one GPU block per output row, one thread per m.
        (output
            .reorder(m, o)
            .gpu_blocks(o).gpu_threads(m)
        )

        # Scatter update: 128 (p) x 32 (m) tile per block, split again so each
        # thread owns a 4 (p0) x 1 (m0) tile -> 32 x 32 threads per block.
        # atomic() makes the read-modify-write of output safe if omap sends
        # several (p, n) to the same row.
        (output.update(0)
            .tile(r2.y, m, p1, m1, 128, 32)
            .tile(p1, m1, p0, m0, 4, 1)
            .reorder(m0, p0, m1, p1, m, r2.y, r2.x)
            .atomic().gpu_blocks(m, r2.y).gpu_threads(m1, p1)
        )

        # Product is computed per thread (at m1, inside the thread loops) and
        # kept in registers; its c reduction is split into c1 x c0 (16).
        (product
            .compute_at(output, m1)
            .store_in(hl.MemoryType.Register)
            .update(0)
            .split(r1.x, c1, c0, 16)
            .reorder(m, c0, p, c1, n)
        )

        # NOTE: product.c1 lies inside the gpu_thread loops, so every thread
        # stages its own 16-element weight slice. The shared allocation becomes
        # 32 * 32 * 16 floats, and no data is actually shared between threads.
        (gather_weight
            .compute_at(product, c1)
            .store_in(hl.MemoryType.GPUShared)
        )

# Instantiate the generator under a host+CUDA target, then JIT-compile it
# into a Python callable.
_target = hl.Target("host-cuda")
with hl.GeneratorContext(_target):
    gen = Test()
f = gen.compile_to_callable()

Outcome (loop nest and lowered Stmt):

produce output:
  gpu_block o<Default_GPU>:
    gpu_thread m<Default_GPU>:
      output(...) = ...
  for r39:
    gpu_block r39.r39<Default_GPU>:
      gpu_block m.m<Default_GPU>:
        gpu_thread r39.p1.p1 in [0, 31]<Default_GPU>:
          gpu_thread m.m1.m1 in [0, 31]<Default_GPU>:
            produce Product:
              for p:
                Product(...) = ...
              for r28.c1:
                produce GatherWeight:
                  for c:
                    GatherWeight(...) = ...
                consume GatherWeight:
                  for p:
                    for r28.c0 in [0, 15]:
                      Product(...) = ...
            consume Product:
              for r39.p1.p0 in [0, 3]:
                output(...) = ...

let t197 = (maxPos + 127)/128
  let t199 = maxPos/128
  let t198 = (output.extent.0 + 31)/32
  let t201 = output.min.1*output.stride.1
  let t200 = (input.min.1*input.stride.1) + input.min.0
  for (output.s1.r39$x, 0, weight.extent.2) {
   let t204 = ((output.s1.r39$x - omap.min.1)*omap.stride.1) - omap.min.0
   let t203 = ((output.s1.r39$x - imap.min.1)*imap.stride.1) - imap.min.0
   let t202 = (output.s1.r39$x*weight.stride.2) + output.min.0
   gpu_block<CUDA> (output.s1.r39$y.r39$y.block_id_y, 0, t197) {
    gpu_block<CUDA> (output.s1.m.m.block_id_x, 0, t198) {
     allocate GatherWeight.0[float32 * 16384] in GPUShared
     gpu_thread<CUDA> (.thread_id_y, 0, 32) {
      gpu_thread<CUDA> (.thread_id_x, 0, 32) {
       if (output.s1.r39$y.r39$y.block_id_y < t199) {
        allocate Product.0[float32 * 4] in Register
        produce Product {
         let Product.s0.p.loop_extent.s = (maxPos - (output.s1.r39$y.r39$y.block_id_y*128)) - (.thread_id_y*4)
         let t205 = min(Product.s0.p.loop_extent.s, 4)
         for (Product.s0.p.rebased, 0, t205) {
          Product.0[Product.s0.p.rebased] = 0.000000f
         }
         let t173.s = (output.s1.m.m.block_id_x*32) + t202
         let t208 = min(Product.s0.p.loop_extent.s, 4)
         let t209 = (((output.s1.r39$y.r39$y.block_id_y*32) + .thread_id_y)*4) + t203
         let t206 = .thread_id_x + t173.s
         let t207 = (.thread_id_y*32) + .thread_id_x
         for (Product.s1.r28$x.c1, 0, 4) {
          produce GatherWeight {
           let t210 = Product.s1.r28$x.c1*16
           for (GatherWeight.s0.c.rebased, 0, 16) {
            GatherWeight.0[(GatherWeight.s0.c.rebased*1024) + t207] = weight[((GatherWeight.s0.c.rebased + t210)*weight.stride.1) + t206]
           }
          }
          consume GatherWeight {
           let t178 = (Product.s1.r28$x.c1*16) - t200
           for (Product.s1.p.rebased, 0, t208) {
            let t179 = Product.s1.p.rebased + t209
            for (Product.s1.r28$x.c0, 0, 16) {
             Product.0[Product.s1.p.rebased] = Product.0[Product.s1.p.rebased] + (input[((imap[t179]*input.stride.1) + t178) + Product.s1.r28$x.c0]*GatherWeight.0[(Product.s1.r28$x.c0*1024) + t207])
            }
           }
          }
         }
        }
        consume Product {
         let t181.s = (output.s1.m.m.block_id_x*32) - t201
         let t212 = (((output.s1.r39$y.r39$y.block_id_y*32) + .thread_id_y)*4) + t204
         let t211 = .thread_id_x + t181.s
         for (output.s1.r39$y.p1.p0, 0, 4) {
          let t139 = (omap[output.s1.r39$y.p1.p0 + t212]*output.stride.1) + t211
          let t140 = Product.0[output.s1.r39$y.p1.p0]
          atomic (output) {
            output[t139] = output[t139] + t140
          }
         }
        }
        free Product.0

Objective:

I want to achieve the following optimizations on the GPU:

  1. Accumulate the product in a 4x1 (p0 x m0) register block.

    • This is successfully achieved using:
      product.compute_at(output, m1).store_in(hl.MemoryType.Register)
  2. Load gather_weight into shared memory at the outer reduction loop (c1) in product.

    • Inside c1, gather_weight requires m1(32) x c0(16) = 512 elements.
    • Per GPU block, there are m1(32) x p1(32) GPU threads. Since gather_weight does not depend on p, the 32 threads along p1 (the y-axis of the GPU threads) could all reuse the same slice if it were loaded into shared memory.
    • Ideal Scenario:
      • Shared Memory Allocation: Allocate only m1(32) x c0(16) = 512 elements.
      • Data Loading: Use only a subset of GPU threads, such as m1(32) x (p1/2)(16), to load gather_weight into shared memory.

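Issue Encountered:

With gather_weight.compute_at(product, c1), GatherWeight is computed inside the gpu_thread loops. As a result, each of the 32 x 32 threads loads its own private 16-element slice. The shared allocation grows to 32 x 32 x 16 = 16384 floats (GatherWeight.0[float32 * 16384]), and nothing is reused across threads.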


Attempted Solution:

I tried adjusting the schedule to compute gather_weight at output (at the block-level m loop) instead of inside product:

# Second attempt: compute GatherWeight once per GPU block (at output's block-level
# m loop) and load it cooperatively with a 32 (m0) x 16 (wc0) thread tile.
# NOTE(review): wc1 and wc0 are not declared in the generator above; they need
# e.g. `wc1, wc0 = hl.vars("wc1 wc0")` before this schedule.
(gather_weight
            .compute_at(output, m)
            .store_in(hl.MemoryType.GPUShared)
            .split(c, wc1, wc0, 16)
            .split(m, m1, m0, 32)
            .gpu_threads(m0, wc0)
)

Loop nest and conceptual Stmt:

produce output:
  gpu_block o<Default_GPU>:
    gpu_thread m<Default_GPU>:
      output(...) = ...
  for r39:
    gpu_block r39.r39<Default_GPU>:
      gpu_block m.m<Default_GPU>:
        produce GatherWeight:
          for c.wc1:
            gpu_thread c.wc0 in [0, 15]<Default_GPU>:
              gpu_thread m.m0 in [0, 31]<Default_GPU>:
                GatherWeight(...) = ...
        consume GatherWeight:
          gpu_thread r39.p1.p1 in [0, 31]<Default_GPU>:
            gpu_thread m.m1.m1 in [0, 31]<Default_GPU>:
              produce Product:
                for p:
                  Product(...) = ...
                for r28.c1:
                  for p:
                    for r28.c0 in [0, 15]:
                      Product(...) = ...
              consume Product:
                for r39.p1.p0 in [0, 3]:
                  output(...) = ...

let t160 = (maxPos + 127)/128
  let t161 = (output.extent.0 + 31)/32
  let t163 = output.min.1*output.stride.1
  let t162 = (input.min.1*input.stride.1) + input.min.0
  for (output.s1.r39$x, 0, weight.extent.2) {
   let t166 = ((output.s1.r39$x - omap.min.1)*omap.stride.1) - omap.min.0
   let t165 = ((output.s1.r39$x - imap.min.1)*imap.stride.1) - imap.min.0
   let t164 = (output.s1.r39$x*weight.stride.2) + output.min.0
   gpu_block<CUDA> (output.s1.r39$y.r39$y.block_id_y, 0, t160) {
    gpu_block<CUDA> (output.s1.m.m.block_id_x, 0, t161) {
     allocate GatherWeight.0[float32 * 2048] in GPUShared
     gpu_thread<CUDA> (.thread_id_y, 0, 32) {
      gpu_thread<CUDA> (.thread_id_x, 0, 32) {
       allocate Product.0[float32 * 4] in Register
       if (.thread_id_y < 16) {
        produce GatherWeight {
         let t143.s = (output.s1.m.m.block_id_x*32) + t164
         let t167 = .thread_id_x + t143.s
         for (GatherWeight.s0.c.wc1, 0, 4) {
          let t158 = (GatherWeight.s0.c.wc1*16) + .thread_id_y
          GatherWeight.0[(t158*32) + .thread_id_x] = weight[(t158*weight.stride.1) + t167]
         }
        }
       }
       gpu_thread_barrier(2)
       consume GatherWeight {
        produce Product {
         let Product.s0.p.loop_extent.s = (maxPos - (output.s1.r39$y.r39$y.block_id_y*128)) - (.thread_id_y*4)
         let t168 = min(Product.s0.p.loop_extent.s, 4)
         for (Product.s0.p.rebased, 0, t168) {
          Product.0[Product.s0.p.rebased] = 0.000000f
         }
         let t169 = min(Product.s0.p.loop_extent.s, 4)
         let t170 = (((output.s1.r39$y.r39$y.block_id_y*32) + .thread_id_y)*4) + t165
         for (Product.s1.r28$x.c1, 0, 4) {
          let t148 = (Product.s1.r28$x.c1*16) - t162
          let t171 = Product.s1.r28$x.c1*16
          for (Product.s1.p.rebased, 0, t169) {
           let t151 = Product.s1.p.rebased + t170
           for (Product.s1.r28$x.c0, 0, 16) {
            Product.0[Product.s1.p.rebased] = Product.0[Product.s1.p.rebased] + (input[((imap[t151]*input.stride.1) + t148) + Product.s1.r28$x.c0]*GatherWeight.0[((Product.s1.r28$x.c0 + t171)*32) + .thread_id_x])
           }
          }
         }
        }
        consume Product {
         let output.s1.r39$y.p1.p0.epilogue.s = maxPos - (((output.s1.r39$y.r39$y.block_id_y*32) + .thread_id_y)*4)
         let t154.s = (output.s1.m.m.block_id_x*32) - t163
         let t172 = max(min(output.s1.r39$y.p1.p0.epilogue.s, 4), 0)
         let t174 = (((output.s1.r39$y.r39$y.block_id_y*32) + .thread_id_y)*4) + t166
         let t173 = .thread_id_x + t154.s
         for (output.s1.r39$y.p1.p0, 0, t172) {
          let t111 = (omap[output.s1.r39$y.p1.p0 + t174]*output.stride.1) + t173
          let t112 = Product.0[output.s1.r39$y.p1.p0]
          atomic (output) {
            output[t111] = output[t111] + t112
          }
         }
        }
        free Product.0
       }
      }
     }
     free GatherWeight.0
    }
   }
  }
 }
}

Is there a way to adjust the Halide schedule to achieve this shared-memory usage? I think .compute_at(product, c1) is necessary at some point, but I don't know how to move these shared-memory loads inside c1 while meeting the requirements above. I feel I'm almost there — or is this kind of loop nest something Halide wasn't designed for?