tlc-pack / tvm-tensorir


[DISCUSS] Tensorization of Warp Level Primitives #371

Open tqchen opened 3 years ago

tqchen commented 3 years ago

AMD's and Nvidia's MFMA (matrix multiplication) instructions operate at the warp level. This creates some interesting challenges for tensorization, semantics checks, and the tensorization infrastructure. This is a discussion issue that tries to capture some of these questions.

Case Study Example: AMD's mfma_32x32x1_f32 instruction

In AMD's case, the GPU has a warp size of 64, i.e. the operations are done collectively by 64 threads, with the inputs and outputs distributed across the registers of each thread. To keep the presentation simple, we will use the following notation.

AMD's mfma_32x32x1_f32 is a batched matmul instruction that performs two matrix outer products. To see what happens, the instruction is equivalent to the following.

Logical Semantics

This is a batch matrix multiplication that divides the warp data into 2 groups of 32x32 and performs the matmul:

for b, i, j in grid(2, 32, 32):
  C[b, i, j] += A[b, i] * B[b, j]

# matrix form: an outer product per batch
for b in grid(2):
  C[b, :, :] += outer(A[b, :], B[b, :])

In order to implement the above logical semantics, C[2, 32, 32], B[2, 32] and A[2, 32] are stored as special registers in warp memory, using the following rule (<=> denotes the memory mapping relation, and wid is the thread index within the warp):

for x, wid in grid(32, 64):
   warpC[x][wid] <=> C[x // 16, x % 16 // 4 * 8 + x % 4 + wid // 32 * 4, wid % 32]

for wid in grid(64):
   warpA[wid] <=> A[wid // 32, wid % 32]

for wid in grid(64):
   warpB[wid] <=> B[wid // 32, wid % 32]

Namely, the data of A, B and C need to be laid out in a special way in the warp-level memory, which in turn maps to the corresponding registers (by removing the wid component).
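
A quick sanity check of the rule above (plain Python, my own sketch rather than part of the original post): enumerating all (x, wid) pairs shows the mapping is a bijection onto C[2, 32, 32].

# enumerate the warpC rule: register index x (0..31) and lane index wid (0..63)
seen = set()
for x in range(32):
    for wid in range(64):
        b = x // 16
        i = x % 16 // 4 * 8 + x % 4 + wid // 32 * 4
        j = wid % 32
        seen.add((b, i, j))
# every element of C[2, 32, 32] is covered exactly once
assert len(seen) == 2 * 32 * 32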

The actual GPU code looks as follows (a simplified example to illustrate the intrinsics):

// perform 2 batches of 32x32 outer products (the batched matmul above)
kernel mfma_kernel(float* globalA[2, 32], float* globalB[2, 32], float* globalC[2, 32, 32]) {
   // assume a single warp (64 threads) in a single block
   int wid = threadIdx.x;
   // only need to allocate one A register and one B register per thread
   float rA[1], rB[1];
   // special registers to store results, need 16*2 = 32 registers per thread to represent warpC
   special_result_float rC[32] = {0};

   rA[0] = globalA[wid / 32, wid % 32];
   rB[0] = globalB[wid / 32, wid % 32];
   // run the intrinsic
   __mfma_32x32x1_f32(rC, rA, rB);
   // store back, following the warpC layout rule above
   for (int i = 0; i < 32; ++i) {
      globalC[i / 16, i % 16 / 4 * 8 + i % 4 + wid / 32 * 4, wid % 32] = rC[i];
   }
}

mfma_kernel<<<1 /* blocks */, 64 /* threads: one warp */>>>(globalA, globalB, globalC);

The above kernel performs

for b, i, j in grid(2, 32, 32):
    globalC[b, i, j] = globalA[b, i] * globalB[b, j] 
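
For reference, a minimal NumPy sketch of the same semantics (random inputs, my own check rather than part of the original post) that the kernel output can be compared against:

import numpy as np

globalA = np.random.rand(2, 32).astype("float32")
globalB = np.random.rand(2, 32).astype("float32")
# globalC[b, i, j] = globalA[b, i] * globalB[b, j]: two independent 32x32 outer products
globalC = np.einsum("bi,bj->bij", globalA, globalB)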

In order to perform the matrix multiplication (tensorization), we need to perform the following steps:

Using BatchMatMul Intrinsic to Implement Matmul

It is possible to use the batch matmul intrinsic above to implement a plain matmul (by replicating one side of the operands). The logic is as follows (defining AA, BB and CC as the inputs and outputs of the matmul):

# replicate BB on all batches
for b, j in grid(2, 32):
    B[b, j] <=> BB[j]

for b, j in grid(2, 32):
    A[b, j] <=> AA[b * 32 + j]

for b, i, j in grid(2, 32, 32):
    C[b, i, j] <=> CC[b * 32 + i, j]

Then we have the following relationship:


for b, i, j in grid(2, 32, 32):
  C[b, i, j] += A[b, i] * B[b, j]

maps to <=>

for b, i, j in grid(2, 32, 32):
  CC[b * 32 + i, j] += AA[b * 32 + i] * BB[j]

This is exactly a 64x32 matmul (with reduction length 1).
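
A minimal NumPy sketch of the replication trick (my own illustration, with hypothetical random inputs): feeding the batched intrinsic a replicated BB reproduces a single 64x32 matmul.

import numpy as np

AA = np.random.rand(64).astype("float32")   # 64 logical rows
BB = np.random.rand(32).astype("float32")   # replicated on both batches

A = AA.reshape(2, 32)                       # A[b, i] = AA[b * 32 + i]
B = np.stack([BB, BB])                      # B[b, j] = BB[j]
C = np.einsum("bi,bj->bij", A, B)           # what the batched intrinsic computes

# flattening (b, i) -> b * 32 + i recovers the 64x32 matmul CC = outer(AA, BB)
assert np.allclose(C.reshape(64, 32), np.outer(AA, BB))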

Challenges and Questions

We can find the following challenges that arise when tensorizing warp-level primitives.

It would be useful to discuss possible ways to solve these challenges, for example:

vinx13 commented 3 years ago

CUDA mma equivalent https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions

tqchen commented 3 years ago

def matmul():
    # the intrinsic: an 8x8 matmul with reduction length 1 on the cached buffers
    for i0, j0, k0 in grid(8, 8, 1):
        CC[i0, j0] += AA[i0, k0] * BB[j0, k0]

def func():
    # the full workload to be tensorized
    for i, j, k in grid(128, 128, 128):
        C[i, j] += A[i, k] * B[j, k]

def func_step0():
    # step 0: split each loop into a 16x16x128 outer grid of 8x8x1 tiles
    for i1, j1, k1 in grid(16, 16, 128):
        for i0, j0, k0 in grid(8, 8, 1):
            C[i1 * 8 + i0, j1 * 8 + j0] += A[i1 * 8 + i0, k1 + k0] * B[j1 * 8 + j0, k1 + k0]

def func_step1():
    # step 1: cache the A and B tiles into AA and BB, tensorize the inner
    # 8x8x1 block, then write CC back to C
    for i1, j1, k1 in grid(16, 16, 128):
        for ia, ka in grid(8, 1):
            AA[ia, ka] = A[i1 * 8 + ia, k1 + ka]
        for ib, kb in grid(8, 1):
            BB[ib, kb] = B[j1 * 8 + ib, k1 + kb]

        with tensorized:
            for i0, j0, k0 in grid(8, 8, 1):
                CC[i0, j0] += AA[i0, k0] * BB[j0, k0]

        for ic, jc in grid(8, 8):
            C[i1 * 8 + ic, j1 * 8 + jc] = CC[ic, jc]

def func_step2():
    # step 2: stage A through Awarp, which uses the warp-level layout
    # (tile row ia lives at Awarp[ia % 2, ia // 2])
    for i1, j1, k1 in grid(16, 16, 128):
        for ia, ka in grid(8, 1):
            Awarp[ia % 2, ia // 2] = A[i1 * 8 + ia, k1 + ka]

        for ib, kb in grid(8, 1):
            BB[ib, kb] = B[j1 * 8 + ib, k1 + kb]

        for ia, ka in grid(8, 1):
            AA[ia, ka] = Awarp[ia % 2, ia // 2]

        with tensorized:
            for i0, j0, k0 in grid(8, 8, 1):
                CC[i0, j0] += AA[i0, k0] * BB[j0, k0]

        for ic, jc in grid(8, 8):
            C[i1 * 8 + ic, j1 * 8 + jc] = CC[ic, jc]

def func_step3():
    # step 3: bind the Awarp copy loop to the warp index
    for i1, j1, k1 in grid(16, 16, 128):

        for i in grid(2):
            for wid in thread_binding("warpIndex", 4):
                Awarp[i, wid] = A[i1 * 8 + wid * 2 + i, k1]

        for ib, kb in grid(8, 1):
            BB[ib, kb] = B[j1 * 8 + ib, k1 + kb]

        for ia, ka in grid(8, 1):
            AA[ia, ka] = Awarp[ia % 2, ia // 2]

        with tensorized:
            for i0, j0, k0 in grid(8, 8, 1):
                CC[i0, j0] += AA[i0, k0] * BB[j0, k0]

        for ic, jc in grid(8, 8):
            C[i1 * 8 + ic, j1 * 8 + jc] = CC[ic, jc]
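
A quick consistency check (my own sketch, assuming the Awarp indexing written above): the per-element copy in func_step2 and the warp-bound copy in func_step3 describe the same Awarp layout.

# step 2 view: tile row ia (0..7) is stored at Awarp[ia % 2, ia // 2]
step2 = {(ia % 2, ia // 2): ia for ia in range(8)}
# step 3 view: lane wid (0..3) holds tile rows wid * 2 + i for i in 0..1
step3 = {(i, wid): wid * 2 + i for wid in range(4) for i in range(2)}
assert step2 == step3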

Hzfengsy commented 3 years ago

I have thought about a new proposal for TensorCore and would like to have some discussion :)

Main Idea: wmma load/store changes data layout.

Currently, we write the load/store intrinsic descriptions like the following code:

with tir.block([16, 16], "store") as [vi, vj]:
    AA[vi, vj] = A[vi, vj]

However, the true behavior of load/store is the following (assuming a 16x16 warp op):

with tir.block([16, 16], "store") as [vi, vj]:
    AA[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj]

Hardware behavior

The warp fragment memory is more or less contiguous (at least at the CUDA level).

With the wmma API, we declare warp memory using wmma::fragment[N] with N fragments, which is similar to float16 data[N][16][16]. Note that the memory is compact, just like a packed layout at the warp memory level.

Cache_read/write with re-layout support

To support this memory layout transformation during scheduling, we need to introduce a new primitive. The current cache_read copies memory and lets all consumers read data from the cached memory in the same layout. However, we can enhance it with index shuffling (somewhat like a CUDA swizzle). Here is an example:

AA = s.cache_read(A, lambda i, j: (i // 16, j // 16, i % 16, j % 16))

And the generated IR is

with tir.block([n, m]) as [i, j]:
    AA[i // 16, j // 16, i % 16, j % 16] = A[i, j]
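
A minimal NumPy sketch of what this produces (assuming a 32x32 input, my own illustration): the 2D buffer is packed into [2][2][16][16] tiles, i.e. the compact fragment layout described earlier.

import numpy as np

A = np.arange(32 * 32).reshape(32, 32)
AA = np.empty((2, 2, 16, 16), dtype=A.dtype)

# apply the cache_read mapping (i, j) -> (i // 16, j // 16, i % 16, j % 16)
for i in range(32):
    for j in range(32):
        AA[i // 16, j // 16, i % 16, j % 16] = A[i, j]

# each AA[ti, tj] is one contiguous 16x16 fragment of A
assert np.array_equal(AA[1, 0], A[16:32, 0:16])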

Benefits

  1. No need for an affine map. The only thing we need is a mapping from (i, j) -> (i // 16, j // 16, i % 16, j % 16), and bijectivity is not required.
  2. Enable storage_align and swizzle (see the sketch after this list).
  3. Native support for TensorCore; no need to consider the warp during tensorize.
  4. May also work on other accelerators.
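
For item 2, a hypothetical swizzle expressed with the same proposed cache_read signature (the concrete mapping below is only an illustration, not a specific hardware swizzle):

# rotate the columns of each row by a row-dependent offset before caching
AA = s.cache_read(A, lambda i, j: (i, (j + i * 4) % 64))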

vinx13 commented 3 years ago

I have elaborated a bit on the workflow in the schedule:

@tvm.script.tir
def intrin_desc(a: ty.handle, b: ty.handle, c: ty.handle):
  # the desc is like a vanilla matmul, with special buffer scopes
  A = tir.match_buffer(a, shape=(16, 16), scope='warp.layoutA')
  B = tir.match_buffer(b, shape=(16, 16), scope='warp.layoutB')
  C = tir.match_buffer(c, shape=(16, 16), scope='warp.layoutC')
  with block('root', [16, 16, tir.reduce_axis(16)]) as [vi, vj, vk]:
    tir.bind(vi, 0)
    tir.bind(vj, 0)
    tir.bind(vk, 0)
    for i, j, k in tir.grid(16, 16, 16):
      with block('C',  [16, 16, tir.reduce_axis(16)]) as [vii, vji, vki]:
        tir.bind(vii, vi + i)
        tir.bind(vji, vj + j)
        tir.bind(vki, vk + k)
        C[vii, vji] += A[vii, vki] * B[vji, vki]

@tvm.script.tir
def intrin_impl(a: ty.handle, b: ty.handle, c: ty.handle):
  # calling warp level intrinsic
  A = tir.match_buffer(a, shape=(16, 16), scope='warp.layoutA')
  B = tir.match_buffer(b, shape=(16, 16), scope='warp.layoutB')
  C = tir.match_buffer(c, shape=(16, 16), scope='warp.layoutC')
  with block('C', [16, 16, tir.reduce_axis(16)]) as [vii, vji, vki]:
    tir.mma_16x16x16(A, B, C, A_frag_index, B_frag_index, C_frag_index) # fragment indices are computed based on elem_offset, such as A.elem_offset // 256

def schedule_fn(sch):
  # split i, j, k and reorder ...
  sch.reorder(i0, j0, k0, i1, j1, k1)
  AA = sch.cache_read(A, 0, 'warp.layoutA')
  BB = sch.cache_read(B, 0, 'warp.layoutB')
  CC = sch.cache_write(C, 0, 'warp.layoutC')
  sch.compute_at(CC, k0)
  sch.compute_at(AA, k0)
  sch.compute_at(BB, k0)
  sch.tensorize(CC, i1, tensor_intrin)

The special layout can be lowered during buffer flattening. The intrinsic mma_16x16x16 also needs to be lowered to use the physical layout; it will become thread-level instructions.
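
To make the fragment-index remark concrete, a minimal sketch of the arithmetic (assuming compactly stored 16x16 fragments of 256 elements each; not the actual lowering pass):

def fragment_index(elem_offset, frag_elems=16 * 16):
    # a compact warp buffer stores whole 16x16 fragments back to back,
    # so the fragment id is elem_offset // 256, matching A.elem_offset // 256 above
    return elem_offset // frag_elems

assert fragment_index(2 * 256) == 2   # the third fragment in the packed buffer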