tlc-pack / tvm-tensorir


[DISCUSS] GPU support and memory hierarchy #51

Closed Hzfengsy closed 4 years ago

Hzfengsy commented 4 years ago

Introduce block_hierarchy

Hardware chips usually have more than one storage and execution hierarchy. NVIDIA GPUs, for example, have GPU blocks (SMs), warps, and CUDA cores, with global, shared, and local memory scopes; each level can access specific memory scopes. On the other hand, TIR with blocks can also be hierarchical: a block can only access buffers that are allocated at the same or an outer scope. Hence, I would like to make this a one-to-one mapping. I will walk through it with a GPU GEMM example.

// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 32
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 32
block(block_hierarchy="GPU_SM") {
  // attr [iter_var(threadIdx.y, range(min=0, ext=8), threadIdx.y)] thread_extent = 8
  // attr [iter_var(threadIdx.x, range(min=0, ext=8), threadIdx.x)] thread_extent = 8
  block(block_hierarchy="GPU_processor") {
    // attr [iter_var(vy, range(min=0, ext=2), vthread)] virtual_thread = 2
    // attr [iter_var(vx, range(min=0, ext=2), vthread)] virtual_thread = 2
    block(init C.local)
    for (k.outer, 0, 256) {
      block(copy from A to A.shared)
      block(copy from B to B.shared)
      for (k.inner.outer, 0, 8) {
        block(copy from A.shared to A.local)
        block(copy from B.shared to B.local)
        block(update C.local)
      }
    }
    block(copy from C.local to C.global)
  }
}

Pros

Simple way

We still introduce block_hierarchy for some special uses (e.g. TensorCore), but allow people to directly bind threads without blockize. The possible problem is that it is hard to do checks during scheduling. However, the good news is that we can do the checks during BufferFlatten.
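
As a rough illustration of direct binding, here is a sketch that reuses the tir schedule API from the example later in this thread (original_func and the split factors are placeholders):

from tvm import te, tir

# Sketch only: bind loops to GPU threads directly, with no blockize and no
# block_hierarchy tag; the scope checks would then have to happen later,
# e.g. during BufferFlatten.
s = tir.create_schedule(original_func)   # original_func as in the example below
C_block = s.get_block("C")
y, x, k = s.get_axes(C_block)
by, ty = s.split(y, 8)                   # inner extent 8 to match threadIdx.y
bx, tx = s.split(x, 8)
s.bind(by, te.thread_axis("blockIdx.y"))
s.bind(bx, te.thread_axis("blockIdx.x"))
s.bind(ty, te.thread_axis((0, 8), "threadIdx.y"))
s.bind(tx, te.thread_axis((0, 8), "threadIdx.x"))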

Questions

I remember that current TVM has a special rule to handle shared memory during bound inference. Maybe I am missing some details in that case.

cc @tqchen @spectrometerHBH

spectrometerHBH commented 4 years ago

Overall LGTM. But I am confused about how we reflect warp-level things in this hierarchy. Grid (global) - Block (shared) - Thread (local) looks natural to me.

tqchen commented 4 years ago

Usually it is grid->block->warp->thread

tqchen commented 4 years ago

To expand @Hzfengsy's question further, the key challenge here is how to design a system of two things: the execution scope and the storage scope.

There are also other factors that we could think about for future applications. In particular, how to represent the scope in a multi-device setting, when a function contains a mix of devices.

def fn(void* [host:cpu] args) {
  float* [gpu:0-global] x = LoadArg(args, 0)
  float* [gpu:1-global] y = LoadArg(args, 1)

  launch [gpu:0-grid] {
     launch [gpu:0-block] {
       // allocate shared memory
       float shared zz[10];
       // code that accesses x
     }
  }
  device_copy_from_to(y, x)
  launch thread-block [gpu:1-block] {
     // code that accesses y
  }
}

The code above contains the following elements: host-side arguments living on the CPU, buffers in the global memory of two different GPUs, nested grid/block launch (execution) scopes on gpu:0, a shared-memory allocation inside a block, a cross-device copy, and a separate launch on gpu:1.

In most of our current use cases, we have not yet dealt with the multi-GPU case. However, we do have the problem of mixed host and GPU code, for which we do not offer very comprehensive memory access checking so far.

Such a relation also exposes design choices:

In terms of where to put this information, we could:

From the system/IR's point of view, C0 is certainly easier for analysis. One potential approach is to "auto-complete" the annotations for users within a given context.
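
As a hedged sketch of that auto-completion (the helper and the default table here are hypothetical, not an existing TVM API): when the user omits a storage scope, pick a default from the innermost enclosing execution scope.

# Hypothetical helper, not an existing TVM API: infer an omitted storage scope
# from the enclosing execution scope (C0 would mean everything is explicit).
DEFAULT_SCOPE = {
    "gpu-grid": "global",    # allocations visible to every block
    "gpu-block": "shared",   # allocations visible to all threads in one block
    "gpu-thread": "local",   # per-thread allocations
}

def autocomplete_scope(explicit_scope, enclosing_exec_scope):
    """Return the explicit scope if given, otherwise a context-based default."""
    if explicit_scope is not None:
        return explicit_scope
    return DEFAULT_SCOPE[enclosing_exec_scope]

# e.g. a buffer allocated inside `launch [gpu:0-block]` with no annotation
# would be completed to "shared".
print(autocomplete_scope(None, "gpu-block"))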

tqchen commented 4 years ago

I do not know a very clear answer to the above question, but I think it might be worth spending time to think about it and designing an infra for storage and execution scopes that fits our future needs.

tqchen commented 4 years ago

Related: https://github.com/apache/incubator-tvm/pull/5190

spectrometerHBH commented 4 years ago

One immature thought for discussion: an execution scope is a subset of running instances, while a storage scope denotes an address space accessed by a subset of running instances.

If an execution scope has no corresponding storage scope, then it may be meaningless to introduce such an execution scope. For example, if there were no shared memory/L1 cache, then it would make little sense to introduce the block level on GPUs.

This may imply that storage scope is more essential than execution scope.

tqchen commented 4 years ago

@spectrometerHBH I think you are right.

We do need to take into consideration the fact that an execution scope can have multiple storage scopes attached to it, e.g. in the case of TensorCore WMMA (wmma.lhs, wmma.right, wmma.result). Of course, all of them belong to the level of "warp".
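
To make the relation concrete, here is a rough sketch of the NVIDIA mapping discussed in this thread, written as plain Python data; the warp-level names follow the WMMA fragments mentioned above, and none of this is a finalized design:

# Illustrative only: one execution scope may own several storage scopes,
# e.g. the warp level owns all three TensorCore fragment scopes.
EXEC_TO_STORAGE = {
    "grid":   ["global"],
    "block":  ["shared"],
    "warp":   ["wmma.lhs", "wmma.right", "wmma.result"],
    "thread": ["local"],
}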

Hzfengsy commented 4 years ago

GPU support contains three major parts: Blockize, Bind, and Memory Scope.

Blockize

Block var and loop binding

It may be the most complex part of the TIR schedule. It is impossible to support arbitrary programs, so I just focus on regular workloads.

  1. All write regions must start from a single var; we can pick each of them as a block var.
  2. We can make the remaining loop_vars that are used inside the block into block vars (a selection sketch follows this list).
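
A runnable sketch of that selection rule, with access regions flattened to plain variable names for illustration (this is not the actual blockize implementation):

# Illustrative only: pick block vars from the write regions first, then from
# the remaining loop vars used inside the block (e.g. reduction axes).
def detect_block_vars(write_regions, loop_vars_used_inside):
    """write_regions: list of (buffer_name, [index var name per dimension])."""
    block_vars = []
    for _buffer, index_vars in write_regions:
        for var in index_vars:
            # rule 1: each written dimension starts from a single var -> block var
            if var not in block_vars:
                block_vars.append(var)
    for var in loop_vars_used_inside:
        # rule 2: remaining loop vars used inside also become block vars
        if var not in block_vars:
            block_vars.append(var)
    return block_vars

# e.g. for the C block of the matmul example below:
print(detect_block_vars([("C", ["vi", "vj"])], ["vi", "vj", "vk"]))  # ['vi', 'vj', 'vk']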

Challenges:

  1. Not all blocks produced by blockize satisfy the binding check.
  2. Need an algorithm to detect the block var iter type.

Buffer Region coverage

Since we need the block to write exactly the buffer region it declares, it is necessary to check buffer coverage during blockize.
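
A minimal sketch of such a coverage check, treating each region as a list of per-dimension (min, extent) ranges (illustrative only, not the actual pass):

# Illustrative only: the region a block declares it writes must be exactly
# covered by what its body actually writes.
def covers(declared_region, actual_region):
    """declared_region/actual_region: list of (min, extent) per buffer dimension."""
    if len(declared_region) != len(actual_region):
        return False
    for (d_min, d_ext), (a_min, a_ext) in zip(declared_region, actual_region):
        if d_min != a_min or d_ext != a_ext:
            return False
    return True

# e.g. a block that declares a 64x64 write but whose body only writes a 32x64 tile
print(covers([(0, 64), (0, 64)], [(0, 32), (0, 64)]))  # False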

block hierarchy

As mentioned before, we need to add hierarchy tags to each block. We would provide two APIs to do it: blockize(loop_sref, hierarchy_tag=None) and set_hierarchy(block_sref, hierarchy_tag).
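
A hedged usage sketch of the two proposed APIs (the signatures follow the proposal above and are not implemented yet; loop_sref and inner_loop_sref are placeholders):

# Proposed usage only; these APIs do not exist yet.
outer_block = s.blockize(loop_sref, hierarchy_tag="GPU_SM")   # tag at blockize time
inner_block = s.blockize(inner_loop_sref)                     # or tag afterwards
s.set_hierarchy(inner_block, "GPU_processor")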

Move Outside Allocation

If a buffer is only accessed inside the new block, we need to move its allocation under the new block.

Bind

I think it is easy since we have done similar things for vectorize and parallel. The only problem I can see is how we define a thread_axis. It is strange to still use the te namespace, as in tvm.te.thread_axis.

Memory Scope

As we discussed before, a buffer with a specific memory scope can only be allocated under a block with the corresponding hardware hierarchy. That is fine for the final schedule, but it may bring trouble during cache_read/cache_write. The critical problem is that the intermediate state of a schedule is not a valid schedule for now (as in the te schedule).
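
For example, with the schedule API used in the example below, the intermediate state right after caching already violates the rule:

# Sketch: immediately after this call, C_local carries the "local" scope but is
# allocated under the root block, and no block tagged with the corresponding
# hardware hierarchy (e.g. GPU thread) exists yet, so a strict scope check on
# this intermediate schedule would reject it. Only the later split/bind steps
# create the thread-level structure that makes the "local" scope legal.
CC = s.cache_write(C, "local")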

spectrometerHBH commented 4 years ago

import numpy as np
import tvm
from tvm import tir
from tvm import te

@tvm.tir.hybrid.script
def matmul(a, b, c):
    C = buffer_bind(c, (2048, 2048), "float32")
    A = buffer_bind(a, (2048, 2048), "float32")
    B = buffer_bind(b, (2048, 2048), "float32")
    reducer = comm_reducer(lambda x, y: x + y, float32(0))

    with block({}, writes=[C[0:2048, 0:2048]], reads=[A[0:2048, 0:2048], B[0:2048, 0:2048]],
               name="root"):
        for i in range(0, 2048):
            for j in range(0, 2048):
                for k in range(0, 2048):
                    with block({vi(0, 2048): i, vj(0, 2048): j, vk(0, 2048, iter_type="reduce"): k},
                               writes=[C[vi:(vi + 1), vj:(vj + 1)]],
                               reads=[C[vi:(vi + 1), vj:(vj + 1)], A[vk:(vk + 1), vi:(vi + 1)],
                                      B[vk:(vk + 1), vj:(vj + 1)]], name="C"):
                        reducer.step(C[vi, vj], A[vk, vi] * B[vk, vj])

n = 2048
device = 'cuda'
ctx = tvm.context(device, 0)
mod = tir.hybrid.create_module({"matmul": matmul})

original_func = mod["matmul"]

a_np = np.random.uniform(size=(n, n)).astype("float32")
b_np = np.random.uniform(size=(n, n)).astype("float32")
a = tvm.nd.array(a_np)
b = tvm.nd.array(b_np)
c = tvm.nd.array(np.zeros((n, n)).astype("float32"))

def build_and_test(func):
    if not ctx.exist:
        print("Skip because %s is not enabled" % device)
        return
    print("Device %s" % device)
    f = tvm.build(func, target=device)
    print(tvm.lower(func))
    f(a, b, c)
    tvm.testing.assert_allclose(c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)
    evaluator = f.time_evaluator(f.entry_name, ctx, number=1)
    return evaluator(a, b, c).mean

scale = 8
num_thread = 8
block_factor = scale * num_thread

block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
block_y = te.thread_axis("blockIdx.y")
thread_y = te.thread_axis((0, num_thread), "threadIdx.y")
thread_xz = te.thread_axis((0, 2), "vthread", name="vx")
thread_yz = te.thread_axis((0, 2), "vthread", name="vy")

s = tir.create_schedule(original_func)
A = original_func.buffer_map[original_func.params[0]]
B = original_func.buffer_map[original_func.params[1]]
C = original_func.buffer_map[original_func.params[2]]

C_block = s.get_block("C")

AA = s.cache_read(A, "shared")
BB = s.cache_read(B, "shared")
AL = s.cache_read(AA.writes[0].buffer, "local")
BL = s.cache_read(BB.writes[0].buffer, "local")
CC = s.cache_write(C, "local")

y, x = s.get_axes(C_block)
by, yi = s.split(y, block_factor)
bx, xi = s.split(x, block_factor)
s.reorder(by, bx, yi, xi)

s.bind(by, block_y)
s.bind(bx, block_x)

tyz, yi = s.split(yi, nparts=2)
ty, yi = s.split(yi, nparts=num_thread)
txz, xi = s.split(xi, nparts=2)
tx, xi = s.split(xi, nparts=num_thread)
s.reorder(tyz, txz, ty, tx, yi, xi)

s.bind(tyz, thread_yz)
s.bind(txz, thread_xz)
s.bind(ty, thread_y)
s.bind(tx, thread_x)

s.compute_at(CC, tx)

y, x, k = s.get_axes(CC)[-3:]
ko, ki = s.split(k, factor=8)
kt, ki = s.split(ki, factor=1)
s.reorder(ko, kt, ki, y, x)

s.compute_at(AL, kt)
s.compute_at(BL, kt)
s.compute_at(AA, ko)
s.compute_at(BB, ko)

s.decompose_reduction(CC, tx)

print(tir.hybrid.ashybrid(s.func))
build_and_test(s.func)
def func(a, b, c):
    B = buffer_bind(b, (2048, 2048), "float32")
    C = buffer_bind(c, (2048, 2048), "float32")
    A = buffer_bind(a, (2048, 2048), "float32")
    with block({}, writes=[C[0:2048, 0:2048]], reads=[A[0:2048, 0:2048], B[0:2048, 0:2048]], name="root"):
        A_shared = buffer_allocate((2048, 2048), "float32", "shared")
        B_shared = buffer_allocate((2048, 2048), "float32", "shared")
        A_shared_local = buffer_allocate((2048, 2048), "float32", "local")
        B_shared_local = buffer_allocate((2048, 2048), "float32", "local")
        C_local = buffer_allocate((2048, 2048), "float32", "local")
        for ax0_outer in range(0, 32, annotation = {"loop_type":"blockIdx.y"}):
            for ax1_outer in range(0, 32, annotation = {"loop_type":"blockIdx.x"}):
                for ax0_inner_outer in range(0, 2, annotation = {"loop_type":"vthread"}):
                    for ax1_inner_outer in range(0, 2, annotation = {"loop_type":"vthread"}):
                        for ax0_inner_inner_outer in range(0, 8, annotation = {"loop_type":"threadIdx.y"}):
                            for ax1_inner_inner_outer in range(0, 8, annotation = {"loop_type":"threadIdx.x"}):
                                for ax0_init in range(0, 4):
                                    for ax1_init in range(0, 4):
                                        with block({vi_init(0, 2048):((((ax0_outer*64) + (ax0_inner_outer*32)) + (ax0_inner_inner_outer*4)) + ax0_init), vj_init(0, 2048):((((ax1_outer*64) + (ax1_inner_outer*32)) + (ax1_inner_inner_outer*4)) + ax1_init)}, writes=[C_local[vi_init:(vi_init + 1), vj_init:(vj_init + 1)]], reads=[], name="C_init"):
                                            C_local[vi_init, vj_init] = float32(0)
                                for ax2_outer in range(0, 256):
                                    for ax0 in range(0, 8):
                                        for ax1 in range(0, 4):
                                            with block({v0(0, 2048):((ax2_outer*8) + ax0), v1(0, 2048):((((ax1_outer*64) + (ax1_inner_outer*32)) + (ax1_inner_inner_outer*4)) + ax1)}, writes=[B_shared[v0:(v0 + 1), v1:(v1 + 1)]], reads=[B[v0:(v0 + 1), v1:(v1 + 1)]], name=""):
                                                B_shared[v0, v1] = B[v0, v1]
                                    for ax0 in range(0, 8):
                                        for ax1 in range(0, 4):
                                            with block({v0(0, 2048):((ax2_outer*8) + ax0), v1(0, 2048):((((ax0_outer*64) + (ax0_inner_outer*32)) + (ax0_inner_inner_outer*4)) + ax1)}, writes=[A_shared[v0:(v0 + 1), v1:(v1 + 1)]], reads=[A[v0:(v0 + 1), v1:(v1 + 1)]], name=""):
                                                A_shared[v0, v1] = A[v0, v1]
                                    for ax2_inner_outer in range(0, 8):
                                        for ax1 in range(0, 4):
                                            with block({v0(0, 2048):((ax2_outer*8) + ax2_inner_outer), v1(0, 2048):((((ax1_outer*64) + (ax1_inner_outer*32)) + (ax1_inner_inner_outer*4)) + ax1)}, writes=[B_shared_local[v0:(v0 + 1), v1:(v1 + 1)]], reads=[B_shared[v0:(v0 + 1), v1:(v1 + 1)]], name=""):
                                                B_shared_local[v0, v1] = B_shared[v0, v1]
                                        for ax1 in range(0, 4):
                                            with block({v0(0, 2048):((ax2_outer*8) + ax2_inner_outer), v1(0, 2048):((((ax0_outer*64) + (ax0_inner_outer*32)) + (ax0_inner_inner_outer*4)) + ax1)}, writes=[A_shared_local[v0:(v0 + 1), v1:(v1 + 1)]], reads=[A_shared[v0:(v0 + 1), v1:(v1 + 1)]], name=""):
                                                A_shared_local[v0, v1] = A_shared[v0, v1]
                                        for ax2_inner_inner in range(0, 1):
                                            for ax0 in range(0, 4):
                                                for ax1 in range(0, 4):
                                                    with block({vi(0, 2048):((((ax0_outer*64) + (ax0_inner_outer*32)) + (ax0_inner_inner_outer*4)) + ax0), vj(0, 2048):((((ax1_outer*64) + (ax1_inner_outer*32)) + (ax1_inner_inner_outer*4)) + ax1), vk(0, 2048, iter_type="reduce"):((ax2_outer*8) + (ax2_inner_outer + ax2_inner_inner))}, writes=[C_local[vi:(vi + 1), vj:(vj + 1)]], reads=[C_local[vi:(vi + 1), vj:(vj + 1)], A_shared_local[vk:(vk + 1), vi:(vi + 1)], B_shared_local[vk:(vk + 1), vj:(vj + 1)]], name="C_update"):
                                                        C_local[vi, vj] = (C_local[vi, vj] + (A_shared_local[vk, vi]*B_shared_local[vk, vj]))
                                for ax0_inner_inner_inner in range(0, 4):
                                    for ax1_inner_inner_inner in range(0, 4):
                                        with block({v0(0, 2048):((ax0_outer*64) + ((ax0_inner_outer*32) + ((ax0_inner_inner_outer*4) + ax0_inner_inner_inner))), v1(0, 2048):((ax1_outer*64) + ((ax1_inner_outer*32) + ((ax1_inner_inner_outer*4) + ax1_inner_inner_inner)))}, writes=[C[v0:(v0 + 1), v1:(v1 + 1)]], reads=[C_local[v0:(v0 + 1), v1:(v1 + 1)]], name=""):
                                            C[v0, v1] = C_local[v0, v1]
PrimFunc([a, b, c]) attrs={"tir.noalias": (bool)1, "global_symbol": "main"} {
  // attr [iter_var(ax0_outer, range(min=0, ext=32), blockIdx.y)] thread_extent = 32
  // attr [C_local] storage_scope = "local"
  allocate C_local[float32 * 64]
  // attr [B_shared] storage_scope = "shared"
  allocate B_shared[float32 * 512]
  // attr [A_shared] storage_scope = "shared"
  allocate A_shared[float32 * 512]
  // attr [B_shared_local] storage_scope = "local"
  allocate B_shared_local[float32 * 8]
  // attr [A_shared_local] storage_scope = "local"
  allocate A_shared_local[float32 * 8]
  // attr [iter_var(ax1_outer, range(min=0, ext=32), blockIdx.x)] thread_extent = 32
  // attr [iter_var(ax0_inner_inner_outer, range(min=0, ext=8), threadIdx.y)] thread_extent = 8
  // attr [iter_var(ax1_inner_inner_outer, range(min=0, ext=8), threadIdx.x)] thread_extent = 8
  for (ax0_init, 0, 4) {
    for (ax1_init, 0, 4) {
      C_local[((ax0_init*4) + ax1_init)] = 0f
      C_local[(((ax0_init*4) + ax1_init) + 32)] = 0f
      C_local[(((ax0_init*4) + ax1_init) + 16)] = 0f
      C_local[(((ax0_init*4) + ax1_init) + 48)] = 0f
    }
  }
  for (ax2_outer, 0, 256) {
    for (ax0, 0, 8) {
      for (ax1, 0, 4) {
        B_shared[(((ax0*32) + (ax1_inner_inner_outer*4)) + ax1)] = B[(((((ax2_outer*16384) + (ax0*2048)) + (ax1_outer*64)) + (ax1_inner_inner_outer*4)) + ax1)]
        B_shared[((((ax0*32) + (ax1_inner_inner_outer*4)) + ax1) + 256)] = B[((((((ax2_outer*16384) + (ax0*2048)) + (ax1_outer*64)) + (ax1_inner_inner_outer*4)) + ax1) + 32)]
      }
    }
    for (ax0, 0, 8) {
      for (ax1, 0, 4) {
        A_shared[(((ax0*32) + (ax0_inner_inner_outer*4)) + ax1)] = A[(((((ax2_outer*16384) + (ax0*2048)) + (ax0_outer*64)) + (ax0_inner_inner_outer*4)) + ax1)]
        A_shared[((((ax0*32) + (ax0_inner_inner_outer*4)) + ax1) + 256)] = A[((((((ax2_outer*16384) + (ax0*2048)) + (ax0_outer*64)) + (ax0_inner_inner_outer*4)) + ax1) + 32)]
      }
    }
    for (ax2_inner_outer, 0, 8) {
      for (ax1, 0, 4) {
        B_shared_local[ax1] = B_shared[(((ax2_inner_outer*32) + (ax1_inner_inner_outer*4)) + ax1)]
        B_shared_local[(ax1 + 4)] = B_shared[((((ax2_inner_outer*32) + (ax1_inner_inner_outer*4)) + ax1) + 256)]
      }
      for (ax1, 0, 4) {
        A_shared_local[ax1] = A_shared[(((ax2_inner_outer*32) + (ax0_inner_inner_outer*4)) + ax1)]
        A_shared_local[(ax1 + 4)] = A_shared[((((ax2_inner_outer*32) + (ax0_inner_inner_outer*4)) + ax1) + 256)]
      }
      for (ax2_inner_inner, 0, 1) {
        for (ax0, 0, 4) {
          for (ax1, 0, 4) {
            C_local[((ax0*4) + ax1)] = (C_local[((ax0*4) + ax1)] + (A_shared_local[ax0]*B_shared_local[ax1]))
            C_local[(((ax0*4) + ax1) + 32)] = (C_local[(((ax0*4) + ax1) + 32)] + (A_shared_local[(ax0 + 4)]*B_shared_local[ax1]))
            C_local[(((ax0*4) + ax1) + 16)] = (C_local[(((ax0*4) + ax1) + 16)] + (A_shared_local[ax0]*B_shared_local[(ax1 + 4)]))
            C_local[(((ax0*4) + ax1) + 48)] = (C_local[(((ax0*4) + ax1) + 48)] + (A_shared_local[(ax0 + 4)]*B_shared_local[(ax1 + 4)]))
          }
        }
      }
    }
  }
  for (ax0_inner_inner_inner, 0, 4) {
    for (ax1_inner_inner_inner, 0, 4) {
      C[((((((ax0_outer*131072) + (ax0_inner_inner_outer*8192)) + (ax0_inner_inner_inner*2048)) + (ax1_outer*64)) + (ax1_inner_inner_outer*4)) + ax1_inner_inner_inner)] = C_local[((ax0_inner_inner_inner*4) + ax1_inner_inner_inner)]
      C[(((((((ax0_outer*131072) + (ax0_inner_inner_outer*8192)) + (ax0_inner_inner_inner*2048)) + (ax1_outer*64)) + (ax1_inner_inner_outer*4)) + ax1_inner_inner_inner) + 65536)] = C_local[(((ax0_inner_inner_inner*4) + ax1_inner_inner_inner) + 32)]
      C[(((((((ax0_outer*131072) + (ax0_inner_inner_outer*8192)) + (ax0_inner_inner_inner*2048)) + (ax1_outer*64)) + (ax1_inner_inner_outer*4)) + ax1_inner_inner_inner) + 32)] = C_local[(((ax0_inner_inner_inner*4) + ax1_inner_inner_inner) + 16)]
      C[(((((((ax0_outer*131072) + (ax0_inner_inner_outer*8192)) + (ax0_inner_inner_inner*2048)) + (ax1_outer*64)) + (ax1_inner_inner_outer*4)) + ax1_inner_inner_inner) + 65568)] = C_local[(((ax0_inner_inner_inner*4) + ax1_inner_inner_inner) + 48)]
    }
  }
}