cornell-zhang / hcl-dialect

HeteroCL-MLIR dialect for accelerator design
https://cornell-zhang.github.io/heterocl/index.html
Other
37 stars 15 forks source link

set_slice for indices > 64 seem incorrect #170

Closed jcasas00 closed 1 year ago

jcasas00 commented 1 year ago
def test():                
    """Reproduce the report: write three 32-bit slices of a uint96 scalar and read them back.

    The two slices below bit 64 print the values that were written; the
    slice covering bits 64..95 prints 3221225472 instead of 3735928547.
    """
    hcl.init()
    def kernel():
        a96 = hcl.scalar(0, "a96", 'uint96')
        # One distinct marker per 32-bit slice: 0xDEADBEE1, 0xDEADBEE2, 0xDEADBEE3.
        a96.v[ 0:32] = 3735928545
        a96.v[32:64] = 3735928546
        a96.v[64:96] = 3735928547
        #
        # Slices entirely below bit 64: these print the values written above.
        hcl.print(a96.v[ 0:32], "h: %d\n") 
        hcl.print(a96.v[32:64], "h: %d\n") 
        #
        # Top slice read two ways (shift-and-mask, then slice syntax). Both were
        # reported to print 3221225472 (0xC0000000) rather than 3735928547.
        hcl.print((a96.v>>64) & 0xffffffff, "h: %d\n")
        hcl.print(a96.v[64:96], "h: %d\n")
        #
        # All-zero (2,) tensor; presumably only here so the kernel returns an output.
        r = hcl.compute((2,), lambda i: 0, dtype=hcl.UInt(32))
        return r
    s = hcl.create_schedule([], kernel)
    # Buffer passed to the built function for the kernel's returned tensor.
    hcl_res = hcl.asarray(np.zeros((2,), dtype=np.uint32), dtype=hcl.UInt(32))
    f = hcl.build(s)
    f(hcl_res)

generates:

h: 3735928545
h: 3735928546
h: 3221225472 <- ???
h: 3221225472 <- ???

jcasas00 commented 1 year ago

Hmm ... or is this an issue with print using %d (as far as I know, %u is not supported yet) when the value has bit 31 set? Or is it because print's %d is really lowered to %f at the LLVM level?

zzzDavid commented 1 year ago

The issue is likely caused by SetSliceOp. I removed the print calls and returned the results instead; printing the resulting numpy array shows the same issue.

    # Variant of the reproduction that returns the slices through the output
    # tensor instead of using hcl.print, to rule out print formatting as the cause.
    hcl.init()
    def kernel():
        a96 = hcl.scalar(0, "a96", 'uint96')
        # One distinct marker per 32-bit slice: 0xDEADBEE1, 0xDEADBEE2, 0xDEADBEE3.
        a96.v[ 0:32] = 3735928545
        a96.v[32:64] = 3735928546
        a96.v[64:96] = 3735928547
        r = hcl.compute((3,), lambda i: 0, dtype=hcl.UInt(32))
        # Copy each 32-bit slice into its own element of the returned tensor.
        r[0] = a96.v[ 0:32]
        r[1] = a96.v[32:64]
        r[2] = a96.v[64:96]
        return r
    s = hcl.create_schedule([], kernel)
    hcl_res = hcl.asarray(np.zeros((3,), dtype=np.uint32), dtype=hcl.UInt(32))
    f = hcl.build(s)
    f(hcl_res)
    # Inspect the values on the host side, bypassing hcl.print entirely.
    print(hcl_res.asnumpy())
zzzDavid commented 1 year ago

After some debugging I was able to reproduce this issue with a much smaller example:

def test():
    """Minimal reproduction: three single-bit writes above bit 63 of a uint96 scalar.

    Bits 64, 65 and 66 are set to 1, 0 and 1, so a96.v[64:96] should print
    5 (0b101); the reported output is 4.
    """
    def kernel():
        a96 = hcl.scalar(0, "a96", 'uint96')
        a96.v[64] = 1
        # Writing 0 to bit 65 is the step that was observed to clear bit 64.
        a96.v[65] = 0
        a96.v[66] = 1
        hcl.print(a96.v[64:96], "%d\n")

    s = hcl.create_schedule([], kernel)
    # Dump the lowered IR to a file so it can be inspected and edited by hand.
    ir = str(hcl.lower(s))
    with open("./ir.mlir", "w") as f:
        f.write(ir)
    # Note: `f` is rebound here from the (closed) file handle to the built function.
    f = hcl.build(s)
    f()

The printed result should be 5 (bits 66..64 = 0b101), but I got 4 instead.

Looking into its IR lowered to LLVM dialect:

// LLVM-dialect lowering of the reduced kernel: set bits 64, 65 and 66 of an
// i96 scalar to 1, 0 and 1, then print the slice starting at bit 64.
module {
  llvm.func @malloc(i64) -> !llvm.ptr<i8>
  // The print format is lowered to "%.0f\n": the value is printed as a double.
  llvm.mlir.global internal constant @frmt_spec0("%.0f\0A\00") {alignment = 32 : i64}
  llvm.func @printf(!llvm.ptr<i8>, ...) -> i32
  llvm.func @top() attributes {bit, itypes = "", otypes = ""} {
    // Allocate storage for one i96 element and pack the pointer into a
    // 5-field struct (two pointers, an i64, and two 1-element i64 arrays).
    %0 = llvm.mlir.constant(0 : index) : i64
    %1 = llvm.mlir.constant(1 : index) : i64
    %2 = llvm.mlir.constant(1 : index) : i64
    %3 = llvm.mlir.null : !llvm.ptr<i96>
    %4 = llvm.getelementptr %3[%1] : (!llvm.ptr<i96>, i64) -> !llvm.ptr<i96>
    %5 = llvm.ptrtoint %4 : !llvm.ptr<i96> to i64
    %6 = llvm.call @malloc(%5) : (i64) -> !llvm.ptr<i8>
    %7 = llvm.bitcast %6 : !llvm.ptr<i8> to !llvm.ptr<i96>
    %8 = llvm.mlir.undef : !llvm.struct<(ptr<i96>, ptr<i96>, i64, array<1 x i64>, array<1 x i64>)>
    %9 = llvm.insertvalue %7, %8[0] : !llvm.struct<(ptr<i96>, ptr<i96>, i64, array<1 x i64>, array<1 x i64>)>
    %10 = llvm.insertvalue %7, %9[1] : !llvm.struct<(ptr<i96>, ptr<i96>, i64, array<1 x i64>, array<1 x i64>)>
    %11 = llvm.mlir.constant(0 : index) : i64
    %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr<i96>, ptr<i96>, i64, array<1 x i64>, array<1 x i64>)>
    %13 = llvm.insertvalue %1, %12[3, 0] : !llvm.struct<(ptr<i96>, ptr<i96>, i64, array<1 x i64>, array<1 x i64>)>
    %14 = llvm.insertvalue %2, %13[4, 0] : !llvm.struct<(ptr<i96>, ptr<i96>, i64, array<1 x i64>, array<1 x i64>)>
    // Store the initial value 0 into the scalar, then load it back.
    %15 = llvm.mlir.constant(0 : i32) : i32
    %16 = llvm.mlir.constant(0 : i96) : i96
    %17 = llvm.extractvalue %14[1] : !llvm.struct<(ptr<i96>, ptr<i96>, i64, array<1 x i64>, array<1 x i64>)>
    %18 = llvm.getelementptr %17[%0] : (!llvm.ptr<i96>, i64) -> !llvm.ptr<i96>
    llvm.store %16, %18 : !llvm.ptr<i96>
    %19 = llvm.mlir.constant(0 : index) : i64
    %20 = llvm.extractvalue %14[1] : !llvm.struct<(ptr<i96>, ptr<i96>, i64, array<1 x i64>, array<1 x i64>)>
    %21 = llvm.getelementptr %20[%19] : (!llvm.ptr<i96>, i64) -> !llvm.ptr<i96>
    %22 = llvm.load %21 : !llvm.ptr<i96>

    // set bit 64 to 1
    %23 = llvm.mlir.constant(true) {unsigned} : i1
    %24 = llvm.mlir.constant(64 : index) {unsigned} : i64
    %25 = llvm.mlir.constant(1 : i96) : i96
    %26 = llvm.mlir.constant(64 : i96) : i96
    %27 = llvm.shl %25, %26  : i96
    // Meant to be the i96 all-ones mask, but the literal is 2^64 - 1, so
    // bits 64..95 of %28 are zero.
    %28 = llvm.mlir.constant(18446744073709551615 : i96) : i96
    %29 = llvm.xor %28, %27  : i96
    // Both the "set" (%30) and "clear" (%31) results are computed; the store
    // below writes %30, so the bad mask does not affect this step.
    %30 = llvm.or %22, %27  : i96
    %31 = llvm.and %22, %29  : i96
    %32 = llvm.mlir.constant(0 : index) : i64
    %33 = llvm.extractvalue %14[1] : !llvm.struct<(ptr<i96>, ptr<i96>, i64, array<1 x i64>, array<1 x i64>)>
    %34 = llvm.getelementptr %33[%32] : (!llvm.ptr<i96>, i64) -> !llvm.ptr<i96>
    llvm.store %30, %34 : !llvm.ptr<i96>

    // set bit 65 to 0
    // somehow this operation affect bit 64 and set it to 0
    // NOTE: this section was edited by hand while debugging; the generated
    // lines that were replaced are kept below as comments.
    %35 = llvm.mlir.constant(0 : index) : i64
    %36 = llvm.extractvalue %14[1] : !llvm.struct<(ptr<i96>, ptr<i96>, i64, array<1 x i64>, array<1 x i64>)>
    %37 = llvm.getelementptr %36[%35] : (!llvm.ptr<i96>, i64) -> !llvm.ptr<i96>
    %38 = llvm.load %37 : !llvm.ptr<i96> // CORRECT
    %41 = llvm.mlir.constant(1 : i96) : i96
    %42 = llvm.mlir.constant(65 : i96) : i96
    %43 = llvm.shl %41, %42  : i96 // %43 is bitmask // should be 2^65
    // %43 = llvm.mlir.constant( 36893488147419103232 : i96) : i96
    %44 = llvm.mlir.constant(18446744073709551615 : i96) : i96
    // %45 = llvm.xor %44, %43  : i96 // %44 is all 1s, %43 is bitmask
    // Hand-substituted value: -(2^65) - 1, i.e. every bit set except bit 65.
    %45 = llvm.mlir.constant(-36893488147419103233 : i96) : i96
    %46 = llvm.or %38, %43  : i96
    // %47 = llvm.and %38, %45  : i96 // %38 is input, %45 is inverted bitmask
    %47 = llvm.and %30, %45  : i96 // %38 is input, %45 is inverted bitmask
    %48 = llvm.mlir.constant(0 : index) : i64
    %49 = llvm.extractvalue %14[1] : !llvm.struct<(ptr<i96>, ptr<i96>, i64, array<1 x i64>, array<1 x i64>)>
    %50 = llvm.getelementptr %49[%48] : (!llvm.ptr<i96>, i64) -> !llvm.ptr<i96>
    // This step clears a bit, so the "and" result (%47) is what gets stored.
    llvm.store %47, %50 : !llvm.ptr<i96> // if I remove this, the result is correct

    // set bit 66 to 1
    %51 = llvm.mlir.constant(0 : index) : i64
    %52 = llvm.extractvalue %14[1] : !llvm.struct<(ptr<i96>, ptr<i96>, i64, array<1 x i64>, array<1 x i64>)>
    %53 = llvm.getelementptr %52[%51] : (!llvm.ptr<i96>, i64) -> !llvm.ptr<i96>
    %54 = llvm.load %53 : !llvm.ptr<i96>
    %55 = llvm.mlir.constant(66 : index) {unsigned} : i64
    %56 = llvm.mlir.constant(1 : i96) : i96
    %57 = llvm.mlir.constant(66 : i96) : i96
    %58 = llvm.shl %56, %57  : i96
    // Same 2^64 - 1 literal as above; harmless here because %61 (the "or") is stored.
    %59 = llvm.mlir.constant(18446744073709551615 : i96) : i96
    %60 = llvm.xor %59, %58  : i96
    %61 = llvm.or %54, %58  : i96
    %62 = llvm.and %54, %60  : i96
    %63 = llvm.mlir.constant(0 : index) : i64
    %64 = llvm.extractvalue %14[1] : !llvm.struct<(ptr<i96>, ptr<i96>, i64, array<1 x i64>, array<1 x i64>)>
    %65 = llvm.getelementptr %64[%63] : (!llvm.ptr<i96>, i64) -> !llvm.ptr<i96>
    llvm.store %61, %65 : !llvm.ptr<i96>

    // Read the slice back: %74 = 95 - 95 = 0, so the shl/lshr pair by %74 is a
    // no-op here; the value is then shifted right by 64 and truncated to i32.
    %66 = llvm.mlir.constant(0 : index) : i64
    %67 = llvm.extractvalue %14[1] : !llvm.struct<(ptr<i96>, ptr<i96>, i64, array<1 x i64>, array<1 x i64>)>
    %68 = llvm.getelementptr %67[%66] : (!llvm.ptr<i96>, i64) -> !llvm.ptr<i96>
    %69 = llvm.load %68 : !llvm.ptr<i96>
    %70 = llvm.mlir.constant(95 : index) {unsigned} : i64
    %71 = llvm.mlir.constant(64 : i96) : i96
    %72 = llvm.mlir.constant(95 : i96) : i96
    %73 = llvm.mlir.constant(95 : i96) : i96
    %74 = llvm.sub %73, %72  : i96
    %75 = llvm.shl %69, %74  : i96
    %76 = llvm.lshr %75, %74  : i96
    %77 = llvm.lshr %76, %71  : i96
    %78 = llvm.trunc %77 : i96 to i32
    // Print: zero-extend to i64, convert to f64, and call printf with @frmt_spec0.
    %79 = llvm.mlir.addressof @frmt_spec0 : !llvm.ptr<array<6 x i8>>
    %80 = llvm.mlir.constant(0 : index) : i64
    %81 = llvm.getelementptr %79[%80, %80] : (!llvm.ptr<array<6 x i8>>, i64, i64) -> !llvm.ptr<i8>
    %82 = llvm.zext %78 : i32 to i64
    %83 = llvm.uitofp %82 : i64 to f64
    %84 = llvm.call @printf(%81, %83) : (!llvm.ptr<i8>, f64) -> i32
    llvm.return
  }
}

Interestingly, it was the following line that caused the error. I replaced it with the constant it should evaluate to, and the result became correct.

%45 = llvm.xor %44, %43  : i96 // %44 is all 1s, %43 is bitmask

So the issue appears to come from llvm.xor: its result is incorrect with 96-bit inputs. If the input is narrower than 64 bits, the result is correct.

zzzDavid commented 1 year ago

Further debugging showed that it is the all-ones mask that is wrong:

    // take the inverse of bitmask
    Value all_one_mask =
        rewriter.create<mlir::arith::ConstantIntOp>(loc, -1, width);

In this case, width is 96, but it only created a constant

%44 = llvm.mlir.constant(18446744073709551615 : i96) : i96

which is the 64-bit all-ones value (2^64 - 1). We need a 96-bit all-ones mask, which is 79228162514264337593543950335 (2^96 - 1).

zzzDavid commented 1 year ago

Fixed this by generating the all-ones bit mask by sign-extending a 1-bit value 0b1 to the target bit width, which works for any width.