cornell-zhang / hcl-dialect

HeteroCL-MLIR dialect for accelerator design
https://cornell-zhang.github.io/heterocl/index.html
Other
40 stars 17 forks source link

[Backend] Outlined function with const tensor cannot pass CPU simulation #104

Closed chhzh123 closed 2 years ago

chhzh123 commented 2 years ago

See the example below. It works fine when the data type is float, or when the `.outline()` primitive is not used.

import heterocl as hcl
import numpy as np
import sys

# Problem sizes: batch, input/output channels, input spatial size, kernel size.
bs = 4
ic = 6
oc = 16
ih = iw = 8
kh = kw = 3
# Output size of a stride-1, unpadded convolution: each spatial dim loses k - 1.
oh = ih - (kh - 1)
ow = iw - (kw - 1)
# Fixed-point type under test (the float variant would be hcl.Float()).
dtype = hcl.Fixed(26, 20)

def test_conv2D_const():
    """Check a fixed-point conv2d with const-tensor weights and an outlined stage.

    The weights ``F`` are embedded with ``hcl.const_tensor``. Stage ``B`` gets
    two reuse buffers (``reuse_at`` on its h and w axes) and is moved into its
    own function with ``s[B].outline()``. The built kernel is compared against
    a NumPy reference convolution (stride 1, no padding).
    """
    hcl.init(dtype)
    A = hcl.placeholder((bs, ic, ih, iw))
    np_B = np.random.random((oc, ic, kh, kw))

    def conv(A):
        rc = hcl.reduce_axis(0, ic)
        rh = hcl.reduce_axis(0, kh)
        rw = hcl.reduce_axis(0, kw)
        # Weights are embedded into the generated module as a constant tensor.
        F = hcl.const_tensor(np_B, "F", dtype)
        B = hcl.compute(
            (bs, oc, oh, ow),
            lambda n, c, h, w: hcl.sum(
                A[n, rc, h + rh, w + rw] * F[c, rc, rh, rw],
                axis=[rc, rh, rw],
                dtype=dtype,
            ),
            name="B",
            dtype=dtype,
        )
        return B

    s = hcl.create_schedule([A], conv)
    B = conv.B
    # Line buffer over h, then a window buffer (built on the line buffer) over w.
    LB = s.reuse_at(A, s[B], B.axis[2])
    s.reuse_at(LB, s[B], B.axis[3])
    s[B].outline()

    print(hcl.lower(s))
    f = hcl.build(s)

    np_A = np.random.random((bs, ic, ih, iw))

    # Reference: unpadded, stride-1 2-D convolution, accumulated one kernel tap
    # (rh, rw) at a time; einsum contracts the input-channel axis i.
    np_C = np.zeros((bs, oc, oh, ow), dtype="float")
    for rh in range(kh):
        for rw in range(kw):
            window = np_A[:, :, rh : rh + oh, rw : rw + ow]
            np_C += np.einsum("nihw,oi->nohw", window, np_B[:, :, rh, rw])

    hcl_A = hcl.asarray(np_A, dtype=dtype)
    # Start the output buffer from zeros, NOT from the reference values:
    # pre-filling it with np_C would let the check pass even if the kernel
    # never wrote its output.
    hcl_C = hcl.asarray(np.zeros((bs, oc, oh, ow)), dtype=dtype)

    f(hcl_A, hcl_C)

    # Same tolerances as np.allclose's defaults, but reports mismatching entries.
    np.testing.assert_allclose(np_C, hcl_C.asnumpy(), rtol=1e-5, atol=1e-8)
    print("Passed!")

if __name__ == "__main__":
    # Running this file as a script executes the reproducer once.
    test_conv2D_const()

However, when `.outline()` is used with the fixed-point data type, it fails with the following LLVM assertion:

python3: /scratch/users/hc676/llvm-project/llvm/lib/IR/Instructions.cpp:508: void llvm::CallInst::init(llvm::FunctionType*, llvm::Value*, llvm::ArrayRef<llvm::Value*>, llvm::ArrayRef<llvm::OperandBundleDefT<llvm::Value*> >, const llvm::Twine&): Assertion `(i >= FTy->getNumParams() || FTy->getParamType(i) == Args[i]->getType()) && "Calling a function with a bad signature!"' failed.
 #0 0x00007f58fb83d92f PrintStackTraceSignalHandler(void*) Signals.cpp:0:0
 #1 0x00007f58fb83b359 SignalHandler(int) Signals.cpp:0:0
 #2 0x00007f5918632630 __restore_rt sigaction.c:0:0
 #3 0x00007f591828b387 raise (/lib64/libc.so.6+0x36387)
 #4 0x00007f591828ca78 abort (/lib64/libc.so.6+0x37a78)
 #5 0x00007f59182841a6 __assert_fail_base (/lib64/libc.so.6+0x2f1a6)

Judging from the generated IR, the lowering itself looks correct:

// Generated IR quoted from the report (presumably the print(hcl.lower(s))
// output for the Fixed(26, 20) + outline() configuration).
// #set(d0) holds iff d0 >= 2; it guards the h/w iterations below that read
// the reuse buffers.
#set = affine_set<(d0) : (d0 - 2 >= 0)>
module {
  // The constant tensor F is stored as raw i64 data, whereas the outlined
  // function receives it as a Fixed<26, 20> memref (see %arg1 below).
  memref.global "private" constant @F : memref<16x6x3x3xi64> = dense<"..."> // omitted here
  // Outlined compute stage B (created by s[B].outline()).
  // %arg0: input A, %arg1: constant tensor F, %arg2: output buffer B.
  func private @Stage_B(%arg0: memref<4x6x8x8x!hcl.Fixed<26, 20>>, %arg1: memref<16x6x3x3x!hcl.Fixed<26, 20>>, %arg2: memref<4x16x6x6x!hcl.Fixed<26, 20>>) attributes {bit, itypes = "___"} {
    %c0 = arith.constant 0 : index
    %0 = memref.alloc() {name = "sum_rv"} : memref<1x!hcl.Fixed<26, 20>>
    %1 = memref.alloc() {name = "B_reuse_3"} : memref<6x3x3x!hcl.Fixed<26, 20>>
    %2 = memref.alloc() {name = "B_reuse_2"} : memref<6x3x8x!hcl.Fixed<26, 20>>
    affine.for %arg3 = 0 to 4 {
      affine.for %arg4 = 0 to 16 {
        affine.for %arg5 = 0 to 8 {
          affine.for %arg6 = 0 to 8 {
            // Shift line buffer %2 (B_reuse_2) by one row; load the new row from %arg0.
            affine.for %arg7 = 0 to 6 {
              %3 = affine.load %2[%arg7, 1, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
              affine.store %3, %2[%arg7, 0, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
              %4 = affine.load %2[%arg7, 2, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
              affine.store %4, %2[%arg7, 1, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
              %5 = affine.load %arg0[%arg3, %arg7, %arg5, %arg6] : memref<4x6x8x8x!hcl.Fixed<26, 20>>
              affine.store %5, %2[%arg7, 2, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
            } {spatial}
            affine.if #set(%arg5) {
              // Shift window buffer %1 (B_reuse_3) by one column; load the new column from %2.
              affine.for %arg7 = 0 to 6 {
                affine.for %arg8 = 0 to 3 {
                  %3 = affine.load %1[%arg7, %arg8, 1] : memref<6x3x3x!hcl.Fixed<26, 20>>
                  affine.store %3, %1[%arg7, %arg8, 0] : memref<6x3x3x!hcl.Fixed<26, 20>>
                  %4 = affine.load %1[%arg7, %arg8, 2] : memref<6x3x3x!hcl.Fixed<26, 20>>
                  affine.store %4, %1[%arg7, %arg8, 1] : memref<6x3x3x!hcl.Fixed<26, 20>>
                  %5 = affine.load %2[%arg7, %arg8, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
                  affine.store %5, %1[%arg7, %arg8, 2] : memref<6x3x3x!hcl.Fixed<26, 20>>
                } {spatial}
              } {spatial}
              affine.if #set(%arg6) {
                // Zero the accumulator sum_rv, then reduce over the rx_0/rx_1/rx_2 loops.
                %c0_i32 = arith.constant 0 : i32
                %3 = hcl.int_to_fixed(%c0_i32) : i32 -> !hcl.Fixed<26, 20>
                affine.store %3, %0[%c0] {to = "sum_rv"} : memref<1x!hcl.Fixed<26, 20>>
                affine.for %arg7 = 0 to 6 {
                  affine.for %arg8 = 0 to 3 {
                    affine.for %arg9 = 0 to 3 {
                      %5 = affine.load %1[%arg7, %arg8, %arg9] : memref<6x3x3x!hcl.Fixed<26, 20>>
                      // F is read through the function argument %arg1, not the global directly.
                      %6 = affine.load %arg1[%arg4, %arg7, %arg8, %arg9] {from = "const_tensor"} : memref<16x6x3x3x!hcl.Fixed<26, 20>>
                      %7 = "hcl.mul_fixed"(%5, %6) : (!hcl.Fixed<26, 20>, !hcl.Fixed<26, 20>) -> !hcl.Fixed<26, 20>
                      %8 = affine.load %0[%c0] {from = "sum_rv"} : memref<1x!hcl.Fixed<26, 20>>
                      %9 = "hcl.add_fixed"(%7, %8) : (!hcl.Fixed<26, 20>, !hcl.Fixed<26, 20>) -> !hcl.Fixed<26, 20>
                      affine.store %9, %0[%c0] {to = "sum_rv"} : memref<1x!hcl.Fixed<26, 20>>
                    } {loop_name = "rx_2", reduction}
                  } {loop_name = "rx_1", reduction}
                } {loop_name = "rx_0", reduction}
                %4 = affine.load %0[%c0] {from = "sum_rv"} : memref<1x!hcl.Fixed<26, 20>>
                // The -2 offsets map the guarded h/w indices back to output coordinates.
                affine.store %4, %arg2[%arg3, %arg4, %arg5 - 2, %arg6 - 2] : memref<4x16x6x6x!hcl.Fixed<26, 20>>
              }
            }
          } {loop_name = "w"}
        } {loop_name = "h"}
      } {loop_name = "c"}
    } {loop_name = "n", stage_name = "B"}
    return
  }
  func @top(%arg0: memref<4x6x8x8x!hcl.Fixed<26, 20>>) -> memref<4x16x6x6x!hcl.Fixed<26, 20>> attributes {itypes = "_", otypes = "_"} {
    // Fetch global @F as a Fixed<26, 20> memref and pass it to @Stage_B.
    %0 = hcl.get_global_fixed @F : memref<16x6x3x3x!hcl.Fixed<26, 20>>
    %1 = memref.alloc() {name = "B"} : memref<4x16x6x6x!hcl.Fixed<26, 20>>
    // NOTE(review): per the discussion, the LLVM "bad signature" assertion comes
    // from lowering this call: FixedToInteger did not yet convert call operations
    // or function signatures that carry fixed-point types.
    call @Stage_B(%arg0, %0, %1) : (memref<4x6x8x8x!hcl.Fixed<26, 20>>, memref<16x6x3x3x!hcl.Fixed<26, 20>>, memref<4x16x6x6x!hcl.Fixed<26, 20>>) -> ()
    return %1 : memref<4x16x6x6x!hcl.Fixed<26, 20>>
  }
}
zzzDavid commented 2 years ago

Ah, this is because the FixedToInteger pass does not yet transform call operations. I'll implement it.

chhzh123 commented 2 years ago

I rewrote the pass so that it generates the following code, in which `get_global_fixed` is placed inside the outlined function (so the constant tensor is no longer passed through the call operation). Why does it still fail?

// Variant produced after rewriting the pass: F is fetched inside the outlined
// function, so it is no longer passed through the call.
// NOTE(review): per the follow-up reply, the remaining failure comes from
// converting function signatures with fixed-point arguments, not from
// hcl.get_global_fixed itself.
func private @Stage_B(%arg0: memref<4x6x8x8x!hcl.Fixed<26, 20>>, %arg1: memref<4x16x6x6x!hcl.Fixed<26, 20>>) attributes {bit, itypes = "__"} {
    %c0 = arith.constant 0 : index
    %0 = hcl.get_global_fixed @F : memref<16x6x3x3x!hcl.Fixed<26, 20>>
    // more computation
    return
  }
  // Caller: allocates the output buffer B and passes it, together with the
  // input %arg0, to the outlined stage; both arguments are Fixed<26, 20> memrefs.
  func @top(%arg0: memref<4x6x8x8x!hcl.Fixed<26, 20>>) -> memref<4x16x6x6x!hcl.Fixed<26, 20>> attributes {itypes = "_", otypes = "_"} {
    %0 = memref.alloc() {name = "B"} : memref<4x16x6x6x!hcl.Fixed<26, 20>>
    call @Stage_B(%arg0, %0) : (memref<4x6x8x8x!hcl.Fixed<26, 20>>, memref<4x16x6x6x!hcl.Fixed<26, 20>>) -> ()
    return %0 : memref<4x16x6x6x!hcl.Fixed<26, 20>>
  }
zzzDavid commented 2 years ago

I don't think this is related to `hcl.get_global_fixed`; the problem is the function-signature transformation when fixed-point types are involved. Let me fix this now.