cornell-zhang / hcl-dialect

HeteroCL-MLIR dialect for accelerator design
https://cornell-zhang.github.io/heterocl/index.html
Other
40 stars 17 forks source link

[Pass] bconv with reuse_at generates incorrect results #107

Closed chhzh123 closed 2 years ago

chhzh123 commented 2 years ago

Finally located the problem -- the condition in the ifOp that uses the referenced induction variable also needs to be updated, since the iteration space has been changed to the output space.

// #set0 is applied to the row loop variable (%arg4 in @top): it holds once
// that variable has reached 2, and gates the compute stage.
#set0 = affine_set<(d0) : (d0 - 2 >= 0)>
// #set1 is the window bounds check, applied in @top as
// (d0, d1, d2, d3) = (%arg5, %arg7, %arg4, %arg6).
// The row loop runs over 0..8 and is only entered for %arg4 >= 2, with the
// output row being %arg4 - 2, so the row constraints must be expressed on
// (d2 - 2) + d3 rather than on d2 + d3. With the unshifted form the check
// rejects %arg4 + %arg6 in {8, 9} and the affected reduction terms are
// replaced by 0, corrupting the last output rows.
#set1 = affine_set<(d0, d1, d2, d3) : (d0 + d1 >= 0, -(d0 + d1) + 7 >= 0, d2 + d3 - 2 >= 0, -(d2 + d3 - 2) + 7 >= 0)>
module {
  // @top reads a 4x1x8x8 i16 input (%arg0) and a 32x1x3x3 i16 weight tensor
  // (%arg1) and returns a freshly allocated 4x32x6x6 i32 result ("B").
  func @top(%arg0: memref<4x1x8x8xi16>, %arg1: memref<32x1x3x3xi16>) -> memref<4x32x6x6xi32> attributes {itypes = "uu", otypes = "s"} {
    %c0 = arith.constant 0 : index
    // Output tensor.
    %0 = memref.alloc() {name = "B"} : memref<4x32x6x6xi32>
    // 3x8 buffer ("B_reuse_2") holding the three most recently loaded input rows.
    %1 = memref.alloc() {name = "B_reuse_2"} : memref<3x8xi16>
    affine.for %arg2 = 0 to 4 {
      affine.for %arg3 = 0 to 32 {
        // The row loop spans all 8 input rows; the compute stage below is
        // gated by #set0 and writes output row %arg4 - 2.
        affine.for %arg4 = 0 to 8 {
          // Shift the buffer up by one row (row 1 -> row 0, row 2 -> row 1)
          // and load input row %arg4 into row 2, one column per iteration.
          affine.for %arg5 = 0 to 8 {
            %2 = affine.load %1[1, %arg5] : memref<3x8xi16>
            affine.store %2, %1[0, %arg5] : memref<3x8xi16>
            %3 = affine.load %1[2, %arg5] : memref<3x8xi16>
            affine.store %3, %1[1, %arg5] : memref<3x8xi16>
            %4 = affine.load %arg0[%arg2, %c0, %arg4, %arg5] : memref<4x1x8x8xi16>
            affine.store %4, %1[2, %arg5] : memref<3x8xi16>
          } {spatial}
          // Compute stage: entered only when #set0 holds for %arg4.
          affine.if #set0(%arg4) {
            affine.for %arg5 = 0 to 6 {
              // Single-element accumulator for the 3x3 reduction, reset to 0
              // for every output element.
              %2 = memref.alloc() {name = "sum_rv"} : memref<1xi32>
              %c0_i32 = arith.constant 0 : i32
              affine.store %c0_i32, %2[%c0] {to = "sum_rv"} : memref<1xi32>
              affine.for %arg6 = 0 to 3 {
                affine.for %arg7 = 0 to 3 {
                  // Bounds guard: operands bind to #set1 as
                  // (d0, d1, d2, d3) = (%arg5, %arg7, %arg4, %arg6).
                  // The else branch contributes 0 to the sum.
                  %4 = affine.if #set1(%arg5, %arg7, %arg4, %arg6) -> i32 {
                    %c16_i32 = arith.constant 16 : i32
                    %7 = arith.extsi %c16_i32 : i32 to i128
                    // Buffered input word and the matching weight word.
                    %8 = affine.load %1[%arg6, %arg5 + %arg7] : memref<3x8xi16>
                    %9 = affine.load %arg1[%arg3, %c0, %arg6, %arg7] {from = "compute_1", unsigned} : memref<32x1x3x3xi16>
                    %10 = arith.xori %8, %9 {unsigned} : i16
                    // The shift/mask/add sequence below uses the constants
                    // 0x5555..., 0x3333..., 0x0F0F... and 0x0101... (truncated
                    // to i16 where applied to i16 values); presumably a SWAR
                    // population count of %10 -- TODO confirm against the
                    // frontend that generated this IR.
                    %c1_i32 = arith.constant 1 : i32
                    %11 = arith.trunci %c1_i32 {unsigned} : i32 to i16
                    %12 = arith.shrui %10, %11 {unsigned} : i16
                    %c6148914691236517205_i64 = arith.constant 6148914691236517205 : i64
                    %13 = arith.trunci %c6148914691236517205_i64 {unsigned} : i64 to i16
                    %14 = arith.andi %12, %13 {unsigned} : i16
                    %15 = arith.subi %10, %14 {unsigned} : i16
                    %c3689348814741910323_i64 = arith.constant 3689348814741910323 : i64
                    %16 = arith.trunci %c3689348814741910323_i64 {unsigned} : i64 to i16
                    %17 = arith.andi %15, %16 {unsigned} : i16
                    %c2_i32 = arith.constant 2 : i32
                    %18 = arith.trunci %c2_i32 {unsigned} : i32 to i16
                    %19 = arith.shrui %15, %18 {unsigned} : i16
                    %20 = arith.andi %19, %16 {unsigned} : i16
                    %21 = arith.addi %17, %20 {unsigned} : i16
                    %c4_i32 = arith.constant 4 : i32
                    %22 = arith.trunci %c4_i32 {unsigned} : i32 to i16
                    %23 = arith.shrui %21, %22 {unsigned} : i16
                    %24 = arith.addi %21, %23 {unsigned} : i16
                    %c1085102592571150095_i64 = arith.constant 1085102592571150095 : i64
                    %25 = arith.trunci %c1085102592571150095_i64 {unsigned} : i64 to i16
                    %26 = arith.andi %24, %25 {unsigned} : i16
                    %27 = arith.extui %26 : i16 to i64
                    %c72340172838076673_i64 = arith.constant 72340172838076673 : i64
                    %28 = arith.muli %27, %c72340172838076673_i64 : i64
                    %c56_i32 = arith.constant 56 : i32
                    %29 = arith.extsi %c56_i32 : i32 to i64
                    %30 = arith.shrui %28, %29 : i64
                    // Yielded value: 16 - (%30 << 1), computed in i128 and
                    // truncated to i32.
                    %31 = arith.extsi %30 : i64 to i128
                    %32 = arith.extsi %c1_i32 : i32 to i64
                    %33 = arith.extsi %32 : i64 to i128
                    %34 = arith.shli %31, %33 : i128
                    %35 = arith.subi %7, %34 : i128
                    %36 = arith.trunci %35 : i128 to i32
                    affine.yield %36 : i32
                  } else {
                    affine.yield %c0_i32 : i32
                  }
                  // sum_rv += guarded term.
                  %5 = affine.load %2[%c0] {from = "sum_rv"} : memref<1xi32>
                  %6 = arith.addi %4, %5 : i32
                  affine.store %6, %2[%c0] {to = "sum_rv"} : memref<1xi32>
                } {loop_name = "B_rx", reduction}
              } {loop_name = "B_ry", reduction}
              // Write the finished sum to output row %arg4 - 2.
              %3 = affine.load %2[%c0] {from = "sum_rv"} : memref<1xi32>
              affine.store %3, %0[%arg2, %arg3, %arg4 - 2, %arg5] : memref<4x32x6x6xi32>
            } {loop_name = "xx"}
          }
        } {loop_name = "yy"}
      } {loop_name = "ff"}
    } {loop_name = "nn", stage_name = "B"}
    return %0 : memref<4x32x6x6xi32>
  }
}