cornell-zhang / hcl-dialect

HeteroCL-MLIR dialect for accelerator design
https://cornell-zhang.github.io/heterocl/index.html
Other
38 stars 17 forks source link

[Backend] Failed to legalize operations for fixed-point arrays #101

Closed chhzh123 closed 2 years ago

chhzh123 commented 2 years ago

See the following conv example.

import heterocl as hcl
import numpy as np
import sys

# Problem sizes for the NCHW convolution (valid padding, stride 1).
bs = 4
ic, oc = 6, 16
ih, iw = 8, 8
kh, kw = 3, 3
oh, ow = ih - kh + 1, iw - kw + 1
# Fixed(24, 12) keeps only 12 fractional bits. Quantization error summed over
# the ic * kh * kw products then exceeds np.allclose's default tolerance.
# Fixed(26, 20) keeps 20 fractional bits, which is enough for the check to pass.
dtype = hcl.Fixed(26, 20)  # alternatively: hcl.Float()

def test_conv2D_nchw():
    """Build a fixed-point NCHW conv2d with two reuse buffers and verify it.

    The kernel reduces over input channels and the kh x kw window. It applies
    ``reuse_at`` on the input along the output row axis, then again along the
    output column axis on the resulting buffer. The kernel is built and run,
    and its output is compared against a float64 NumPy reference.
    """
    hcl.init(dtype)
    A = hcl.placeholder((bs, ic, ih, iw))
    F = hcl.placeholder((oc, ic, kh, kw))

    def conv(A, F):
        rc = hcl.reduce_axis(0, ic)
        rh = hcl.reduce_axis(0, kh)
        rw = hcl.reduce_axis(0, kw)
        B = hcl.compute(
            (bs, oc, oh, ow),
            lambda n, c, h, w: hcl.sum(
                A[n, rc, h + rh, w + rw] * F[c, rc, rh, rw],
                axis=[rc, rh, rw],
                dtype=dtype,
            ),
            name="B",
            dtype=dtype,
        )
        return B

    s = hcl.create_schedule([A, F], conv)
    print(s.device_module)
    B = conv.B
    # Reuse along the output row axis, then reuse that buffer along the
    # output column axis. The second buffer is only needed as a schedule
    # side effect, so its handle is not kept.
    LB = s.reuse_at(A, s[B], B.axis[2])
    s.reuse_at(LB, s[B], B.axis[3])
    f = hcl.build(s)

    # Seeded generator so a failing run can be reproduced exactly.
    rng = np.random.default_rng(0)
    np_A = rng.random((bs, ic, ih, iw))
    np_B = rng.random((oc, ic, kh, kw))

    # Reference: for each kernel tap (rh, rw), contract the shifted input
    # window with that tap's weights over the input-channel axis. This equals
    # sum_{rc,rh,rw} A[n, rc, y + rh, x + rw] * B[c, rc, rh, rw].
    np_C = np.zeros((bs, oc, oh, ow), dtype="float")
    for rh in range(kh):
        for rw in range(kw):
            np_C += np.einsum(
                "nchw,oc->nohw",
                np_A[:, :, rh : rh + oh, rw : rw + ow],
                np_B[:, :, rh, rw],
            )

    hcl_A = hcl.asarray(np_A, dtype=dtype)
    hcl_B = hcl.asarray(np_B, dtype=dtype)
    # Start the output from zeros, NOT from the expected values. Otherwise the
    # check below would pass even if the kernel never wrote B.
    hcl_C = hcl.asarray(np.zeros((bs, oc, oh, ow)), dtype=dtype)

    f(hcl_A, hcl_B, hcl_C)

    assert np.allclose(np_C, hcl_C.asnumpy())
    print("Passed!")

if __name__ == "__main__":
    test_conv2D_nchw()

Running it produces the following errors:

loc("-":12:13): error: failed to legalize operation 'affine.store'
loc("-":2:3): error: cannot be converted to LLVM IR: missing `LLVMTranslationDialectInterface` registration for dialect for op: builtin.func
zzzDavid commented 2 years ago

It seems the issue is caused by an `unrealized_conversion_cast` op generated in an intermediate pass:

// Intermediate IR for the conv example. B_reuse_2 / B_reuse_3 are the reuse
// buffers; the zero-initialization of sum_rv still uses an unlowered cast.
#set = affine_set<(d0) : (d0 - 2 >= 0)>
module {
  func @top(%arg0: memref<4x6x8x8x!hcl.Fixed<24, 12>>, %arg1: memref<16x6x3x3x!hcl.Fixed<24, 12>>) -> memref<4x16x6x6x!hcl.Fixed<24, 12>> attributes {itypes = "__", otypes = "_"} {
    %0 = memref.alloc() {name = "B"} : memref<4x16x6x6x!hcl.Fixed<24, 12>>
    %1 = memref.alloc() {name = "B_reuse_2"} : memref<6x3x8x!hcl.Fixed<24, 12>>
    %2 = memref.alloc() {name = "B_reuse_3"} : memref<6x3x3x!hcl.Fixed<24, 12>>
    affine.for %arg2 = 0 to 4 {
      affine.for %arg3 = 0 to 16 {
        affine.for %arg4 = 0 to 8 {
          affine.for %arg5 = 0 to 8 {
            affine.for %arg6 = 0 to 6 {
              %3 = affine.load %1[%arg6, 1, %arg5] : memref<6x3x8x!hcl.Fixed<24, 12>>
              affine.store %3, %1[%arg6, 0, %arg5] : memref<6x3x8x!hcl.Fixed<24, 12>>
              %4 = affine.load %1[%arg6, 2, %arg5] : memref<6x3x8x!hcl.Fixed<24, 12>>
              affine.store %4, %1[%arg6, 1, %arg5] : memref<6x3x8x!hcl.Fixed<24, 12>>
              %5 = affine.load %arg0[%arg2, %arg6, %arg4, %arg5] : memref<4x6x8x8x!hcl.Fixed<24, 12>>
              affine.store %5, %1[%arg6, 2, %arg5] : memref<6x3x8x!hcl.Fixed<24, 12>>
            } {spatial}
            affine.if #set(%arg4) {
              affine.for %arg6 = 0 to 6 {
                affine.for %arg7 = 0 to 3 {
                  %3 = affine.load %2[%arg6, %arg7, 1] : memref<6x3x3x!hcl.Fixed<24, 12>>
                  affine.store %3, %2[%arg6, %arg7, 0] : memref<6x3x3x!hcl.Fixed<24, 12>>
                  %4 = affine.load %2[%arg6, %arg7, 2] : memref<6x3x3x!hcl.Fixed<24, 12>>
                  affine.store %4, %2[%arg6, %arg7, 1] : memref<6x3x3x!hcl.Fixed<24, 12>>
                  %5 = affine.load %1[%arg6, %arg7, %arg5] : memref<6x3x8x!hcl.Fixed<24, 12>>
                  affine.store %5, %2[%arg6, %arg7, 2] : memref<6x3x3x!hcl.Fixed<24, 12>>
                } {spatial}
              } {spatial}
              affine.if #set(%arg5) {
                %3 = memref.alloc() {name = "sum_rv"} : memref<1x!hcl.Fixed<24, 12>>
                %c0 = arith.constant 0 : index
                %c0_i32 = arith.constant 0 : i32
                %4 = builtin.unrealized_conversion_cast %c0_i32 : i32 to !hcl.Fixed<24, 12> // THIS cast: i32 -> Fixed is never lowered, presumably why the affine.store below fails to legalize; it should be an hcl int_to_fixed op
                affine.store %4, %3[%c0] {to = "sum_rv"} : memref<1x!hcl.Fixed<24, 12>>
                affine.for %arg6 = 0 to 6 {
                  affine.for %arg7 = 0 to 3 {
                    affine.for %arg8 = 0 to 3 {
                      %6 = affine.load %2[%arg6, %arg7, %arg8] : memref<6x3x3x!hcl.Fixed<24, 12>>
                      %7 = affine.load %arg1[%arg3, %arg6, %arg7, %arg8] {from = "compute_1"} : memref<16x6x3x3x!hcl.Fixed<24, 12>>
                      %8 = "hcl.mul_fixed"(%6, %7) : (!hcl.Fixed<24, 12>, !hcl.Fixed<24, 12>) -> !hcl.Fixed<24, 12>
                      %9 = affine.load %3[%c0] {from = "sum_rv"} : memref<1x!hcl.Fixed<24, 12>>
                      %10 = "hcl.add_fixed"(%8, %9) : (!hcl.Fixed<24, 12>, !hcl.Fixed<24, 12>) -> !hcl.Fixed<24, 12>
                      affine.store %10, %3[%c0] {to = "sum_rv"} : memref<1x!hcl.Fixed<24, 12>>
                    } {loop_name = "rx_2", reduction}
                  } {loop_name = "rx_1", reduction}
                } {loop_name = "rx_0", reduction}
                %5 = affine.load %3[%c0] {from = "sum_rv"} : memref<1x!hcl.Fixed<24, 12>>
                affine.store %5, %0[%arg2, %arg3, %arg4 - 2, %arg5 - 2] : memref<4x16x6x6x!hcl.Fixed<24, 12>>
              }
            }
          } {loop_name = "w"}
        } {loop_name = "h"}
      } {loop_name = "c"}
    } {loop_name = "n", stage_name = "B"}
    return %0 : memref<4x16x6x6x!hcl.Fixed<24, 12>>
  }
}

This `unrealized_conversion_cast` is probably generated in LoopTransformation; it should be changed to an `int_to_fixed` op.

chhzh123 commented 2 years ago

Here's the problem: https://github.com/cornell-zhang/hcl-dialect-prototype/blob/848f42af62247e21276be7f7580140fed38c99bd/include/hcl/Bindings/Python/hcl/build_ir.py#L2337-L2339

chhzh123 commented 2 years ago

Also, the dtype should be changed to `hcl.Fixed(26, 20)` for the test to pass.