tqchen opened this issue 3 years ago
def matmul():
    for i0, j0, k0 in grid(8, 8, 1):
        CC[i0, j0] += AA[i0, k0] * BB[j0, k0]
def func():
    for i, j, k in grid(128, 128, 128):
        C[i, j] += A[i, k] * B[j, k]
def func_step0():
    for i1, j1, k1 in grid(16, 16, 128):
        for i0, j0, k0 in grid(8, 8, 1):
            C[i1 * 8 + i0, j1 * 8 + j0] += A[i1 * 8 + i0, k1 + k0] * B[j1 * 8 + j0, k1 + k0]
def func_step1():
    for i1, j1, k1 in grid(16, 16, 128):
        for ia, ka in grid(8, 1):
            AA[ia, ka] = A[i1 * 8 + ia, k1 + ka]
        for ib, kb in grid(8, 1):
            BB[ib, kb] = B[j1 * 8 + ib, k1 + kb]
        with tensorized:
            for i0, j0, k0 in grid(8, 8, 1):
                CC[i0, j0] += AA[i0, k0] * BB[j0, k0]
        for ic, jc in grid(8, 8):
            C[i1 * 8 + ic, j1 * 8 + jc] = CC[ic, jc]
def func_step2():
    for i1, j1, k1 in grid(16, 16, 128):
        for ia, ka in grid(8, 1):
            Awarp[ia % 2, ia // 2] = A[i1 * 8 + ia, k1 + ka]
        for ib, kb in grid(8, 1):
            BB[ib, kb] = B[j1 * 8 + ib, k1 + kb]
        for ia, ka in grid(8, 1):
            AA[ia, ka] = Awarp[ia % 2, ia // 2]
        with tensorized:
            for i0, j0, k0 in grid(8, 8, 1):
                CC[i0, j0] += AA[i0, k0] * BB[j0, k0]
        for ic, jc in grid(8, 8):
            C[i1 * 8 + ic, j1 * 8 + jc] = CC[ic, jc]
def func_step3():
    for i1, j1, k1 in grid(16, 16, 128):
        for i in grid(2):
            for wid in thread_binding("warpIndex", 4):
                Awarp[i, wid] = A[i1 * 8 + wid * 2 + i, k1]
        for ib, kb in grid(8, 1):
            BB[ib, kb] = B[j1 * 8 + ib, k1 + kb]
        for ia, ka in grid(8, 1):
            AA[ia, ka] = Awarp[ia % 2, ia // 2]
        with tensorized:
            for i0, j0, k0 in grid(8, 8, 1):
                CC[i0, j0] += AA[i0, k0] * BB[j0, k0]
        for ic, jc in grid(8, 8):
            C[i1 * 8 + ic, j1 * 8 + jc] = CC[ic, jc]
I have thought about a new proposal for TensorCore. I would like to have some discussion :)
Currently, we write the load/store intrinsic description like the following:
with tir.block([16, 16], "store") as [vi, vj]:
    AA[vi, vj] = A[vi, vj]
However, the true behavior of the load/store is the following (assume that we have a 16x16 warp op):
with tir.block([16, 16], "store") as [vi, vj]:
    AA[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj]
The warp fragment memory is somewhat contiguous (at least at the CUDA level). With the wmma API, we declare warp memory using wmma::fragment[N] with N fragments, which is similar to float16 data[N][16][16]. Note that the memory is compact, just like a packed layout at the warp-memory level.
To support this memory layout transformation during scheduling, we need to introduce a new primitive. The current cache_read copies memory and lets all consumers read data from the cached memory in the same layout. However, we can enhance it with index shuffling (somewhat like CUDA swizzle). Here is an example:
AA = s.cache_read(A, lambda i, j: (i // 16, j // 16, i % 16, j % 16))
And the generated IR is
with tir.block([n, m]) as [i, j]:
    AA[i // 16, j // 16, i % 16, j % 16] = A[i, j]
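To make the packed layout concrete, here is a small NumPy sketch (my own illustration, not TVM code; the buffer names and the 16x16 tile size are taken from the example above) checking that this shuffled cache layout is just a tiled packing of the original buffer:

```python
import numpy as np

n, m = 64, 48                      # any multiples of 16 work for this illustration
A = np.arange(n * m, dtype=np.float16).reshape(n, m)

# Packed cache: AA[i // 16, j // 16, i % 16, j % 16] = A[i, j]
AA = np.empty((n // 16, m // 16, 16, 16), dtype=A.dtype)
for i in range(n):
    for j in range(m):
        AA[i // 16, j // 16, i % 16, j % 16] = A[i, j]

# The same packing expressed as reshape + transpose: each 16x16 tile is stored
# contiguously, matching the fragment-style layout data[N][16][16] above.
packed = A.reshape(n // 16, 16, m // 16, 16).transpose(0, 2, 1, 3)
assert np.array_equal(AA, packed)
```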
The index mapping here is i, j -> i // 16, j // 16, i % 16, j % 16, and the mapping is not required to be bijective. Primitives such as storage_align and swizzle can be expressed as special cases of such a mapping, and it also lets us declare the special warp fragment layout that is consumed during tensorize. I have elaborated a bit on the workflow in the schedule:
@tvm.script.tir
def intrin_desc(a: ty.handle, b: ty.handle, c: ty.handle):
    # desc is like a vanilla matmul, with special buffer scopes
    A = tir.match_buffer(a, shape=(16, 16), scope='warp.layoutA')
    B = tir.match_buffer(b, shape=(16, 16), scope='warp.layoutB')
    C = tir.match_buffer(c, shape=(16, 16), scope='warp.layoutC')

    with block('root', [16, 16, tir.reduce_axis(16)]) as [vi, vj, vk]:
        tir.bind(vi, 0)
        tir.bind(vj, 0)
        tir.bind(vk, 0)
        for i, j, k in tir.grid(16, 16, 16):
            with block('C', [16, 16, tir.reduce_axis(16)]) as [vii, vji, vki]:
                tir.bind(vii, vi + i)
                tir.bind(vji, vj + j)
                tir.bind(vki, vk + k)
                C[vii, vji] += A[vii, vki] * B[vji, vki]
@tvm.script.tir
def intrin_impl(a: ty.handle, b: ty.handle, c: ty.handle):
    # calling the warp-level intrinsic
    A = tir.match_buffer(a, shape=(16, 16), scope='warp.layoutA')
    B = tir.match_buffer(b, shape=(16, 16), scope='warp.layoutB')
    C = tir.match_buffer(c, shape=(16, 16), scope='warp.layoutC')

    with block('C', [16, 16, tir.reduce_axis(16)]) as [vii, vji, vki]:
        # fragment indices are computed based on elem_offset, such as A.elem_offset // 256
        tir.mma_16x16x16(A, B, C, A_frag_index, B_frag_index, C_frag_index)
def schedule_fn(sch):
    # split i, j, k and reorder ...
    sch.reorder(i0, j0, k0, i1, j1, k1)
    AA = sch.cache_read(A, 0, 'warp.layoutA')
    BB = sch.cache_read(B, 0, 'warp.layoutB')
    CC = sch.cache_write(C, 0, 'warp.layoutC')
    sch.compute_at(CC, k0)
    sch.compute_at(AA, k0)
    sch.compute_at(BB, k0)
    sch.tensorize(CC, i1, tensor_intrin)
The special layout can be lowered during buffer flattening. The intrinsic mma_16x16x16 also needs to be lowered to use the physical layout; it will become thread-level instructions.
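As a rough illustration of the fragment-index part of that lowering (not an actual pass; the helper name is hypothetical, and the 16x16 fragment size follows the A.elem_offset // 256 comment in intrin_impl above):

```python
FRAGMENT_ELEMS = 16 * 16  # one fragment covers a 16x16 tile

def fragment_index(elem_offset: int) -> int:
    """Map the flattened element offset of a 'warp.*' scoped buffer to the
    fragment index passed to tir.mma_16x16x16, e.g. A.elem_offset // 256."""
    assert elem_offset % FRAGMENT_ELEMS == 0, "offset must be fragment-aligned"
    return elem_offset // FRAGMENT_ELEMS

# e.g. the third 16x16 fragment starts at element offset 512
assert fragment_index(512) == 2
```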
AMD's and NVIDIA's MFMA (matrix multiplication) operators operate at the warp level. This creates some interesting challenges for tensorization, semantics checks, and the tensorization infra. This is a discussion issue that tries to capture some of these questions.
Case Study Example: AMD's mfma32x32x1_f32 instruction
In AMD's case, the GPU has warp size = 64, i.e. the operations are done collectively by 64 threads, where the inputs and outputs are distributed along the registers of each thread. To make the presentation simple, we will use the following notation.
This will get lowered into the following code in the thread-level view.
AMD's mfma32x32x1_f32 is a batched matmul instruction that performs two matrix outer products. To see what happens, the instruction is equivalent to:
Logical Semantics
This is a batch matrix multiplication that divides the warp data into two 32x32 groups and performs the matmul.
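A small NumPy reference of this logical semantics (a sketch of my own; the shapes follow C[2, 32, 32], A[2, 32] and B[2, 32] below, and the rank-1 update per batch reflects the k = 1 outer product):

```python
import numpy as np

def mfma_32x32x1_reference(A, B, C):
    """Logical semantics: two independent 32x32 outer-product updates,
    one per batch (k = 1, so each batch is a rank-1 update)."""
    assert A.shape == (2, 32) and B.shape == (2, 32) and C.shape == (2, 32, 32)
    for b in range(2):
        for i in range(32):
            for j in range(32):
                C[b, i, j] += A[b, i] * B[b, j]
    return C

A = np.random.rand(2, 32).astype(np.float32)
B = np.random.rand(2, 32).astype(np.float32)
C = np.zeros((2, 32, 32), dtype=np.float32)
mfma_32x32x1_reference(A, B, C)
assert np.allclose(C, np.einsum('bi,bj->bij', A, B))
```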
In order to implement the above logical semantics, C[2, 32, 32], B[2, 32] and A[2, 32] are stored as special registers in warp memory, using the following rule (<=> means the memory mapping relation, wid is the warp index). Namely, the data of A, B and C needs to be laid out in a special way in the warp-level memory, which in turn maps to the corresponding registers (by removing the wid component).
The actual GPU code looks like the following (using a simple example to illustrate the intrinsics).
The above kernel performs
In order to perform the matrix multiplication (tensorization), we need to perform the following steps:
Using BatchMatMul Intrinsic to Implement Matmul
It is possible to use the batch matmul intrinsic above to implement a plain matmul (by replicating one side of the operands). The logic is as follows (defining BB, AA, CC as the inputs and outputs of the matmul):
Then we have the following relationship:
This is exactly a 64x32 matmul.
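As a sketch of this replication trick (my own illustration of the relationship above; it assumes B is the replicated side and that k = 1, following the 32x32x1 intrinsic):

```python
import numpy as np

# Inputs of a 64x32(x1) matmul: CC[64, 32] += AA[64] * BB[32] (outer product, k = 1).
AA = np.random.rand(64).astype(np.float32)
BB = np.random.rand(32).astype(np.float32)
CC = np.zeros((64, 32), dtype=np.float32)

# Map onto the batched intrinsic: split the 64 rows into 2 batches of 32 rows,
# and replicate BB into both batches.
A = AA.reshape(2, 32)                  # A[b, i] = AA[b * 32 + i]
B = np.stack([BB, BB])                 # B[b, j] = BB[j]  (replicated side)
C = np.zeros((2, 32, 32), dtype=np.float32)
for b in range(2):                     # the intrinsic's logical semantics
    C[b] += np.outer(A[b], B[b])

# Folding the batch dimension back into the rows recovers the 64x32 result.
CC += C.reshape(64, 32)
assert np.allclose(CC, np.outer(AA, BB))
```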
Challenges and Questions
We can find the following challenges that arise when tensorizing warp-level primitives.
It would be useful to discuss possible ways to solve these challenges, for example: