kumasento / polymer

Bridging polyhedral analysis tools to the MLIR framework
MIT License
99 stars 21 forks source link

[Benchmark][AES] handle a mixture of affine and non-affine code #104

Closed kumasento closed 2 years ago

kumasento commented 2 years ago

Input code -


func @encrypt(%arg0: memref<?x16xi32>, %arg1: memref<?xi32>) attributes {llvm.linkage = #llvm.linkage<external>} {
  %c1_i32 = arith.constant 1 : i32
  %c4_i32 = arith.constant 4 : i32
  %c15_i32 = arith.constant 15 : i32
  %c8_i32 = arith.constant 8 : i32
  %c283_i32 = arith.constant 283 : i32
  %0 = memref.alloca() : memref<1024xi32>
  affine.for %arg2 = 1 to 5 {
    affine.for %arg3 = 0 to 16 {
      %1 = affine.load %arg1[%arg3 * 4] : memref<?xi32>
      %2 = arith.shrsi %1, %c4_i32 : i32
      %3 = arith.index_cast %2 : i32 to index
      %4 = arith.andi %1, %c15_i32 : i32
      %5 = arith.index_cast %4 : i32 to index
      %6 = memref.load %arg0[%3, %5] : memref<?x16xi32>
      affine.store %6, %arg1[%arg3 * 4] : memref<?xi32>
    }
    affine.for %arg3 = 0 to 1023 {
      %1 = affine.load %arg1[%arg3] : memref<?xi32>
      %2 = arith.shli %1, %c1_i32 : i32
      affine.store %2, %0[%arg3] : memref<1024xi32>
      %3 = arith.shrsi %2, %c8_i32 : i32
      %4 = arith.cmpi eq, %3, %c1_i32 : i32
      scf.if %4 {
        %10 = arith.xori %2, %c283_i32 : i32
        affine.store %10, %0[%arg3] : memref<1024xi32>
      }
      %5 = affine.load %arg1[%arg3 + 1] : memref<?xi32>
      %6 = arith.shli %5, %c1_i32 : i32
      %7 = arith.xori %5, %6 : i32
      %8 = arith.shrsi %7, %c8_i32 : i32
      %9 = arith.cmpi eq, %8, %c1_i32 : i32
      scf.if %9 {
        %10 = arith.xori %7, %c283_i32 : i32
        %11 = affine.load %0[%arg3] : memref<1024xi32>
        %12 = arith.xori %11, %10 : i32
        affine.store %12, %0[%arg3] : memref<1024xi32>
      } else {
        %10 = affine.load %0[%arg3] : memref<1024xi32>
        %11 = arith.xori %10, %7 : i32
        affine.store %11, %0[%arg3] : memref<1024xi32>
      }
    }
    affine.for %arg3 = 0 to 1024 {
      %1 = affine.load %0[%arg3] : memref<1024xi32>
      affine.store %1, %arg1[%arg3] : memref<?xi32>
    }
  }
  return
}
kumasento commented 2 years ago

After reg2mem -

  func @encrypt(%arg0: memref<?x16xi32>, %arg1: memref<?xi32>) attributes {llvm.linkage = #llvm.linkage<external>} {
    %0 = memref.alloca() {scop.scratchpad} : memref<1xi32>
    %1 = memref.alloca() {scop.scratchpad} : memref<1xi32>
    %c1_i32 = arith.constant 1 : i32
    %c4_i32 = arith.constant 4 : i32
    %c15_i32 = arith.constant 15 : i32
    %c8_i32 = arith.constant 8 : i32
    %c283_i32 = arith.constant 283 : i32
    %2 = memref.alloca() : memref<1024xi32>
    affine.for %arg2 = 1 to 5 {
      affine.for %arg3 = 0 to 16 {
        %3 = affine.load %arg1[%arg3 * 4] : memref<?xi32>
        %4 = arith.shrsi %3, %c4_i32 : i32
        %5 = arith.index_cast %4 : i32 to index
        %6 = arith.andi %3, %c15_i32 : i32
        %7 = arith.index_cast %6 : i32 to index
        %8 = memref.load %arg0[%5, %7] : memref<?x16xi32>
        affine.store %8, %arg1[%arg3 * 4] : memref<?xi32>
      }
      affine.for %arg3 = 0 to 1023 {
        %3 = affine.load %arg1[%arg3] : memref<?xi32>
        %4 = arith.shli %3, %c1_i32 : i32
        affine.store %4, %1[0] : memref<1xi32>
        affine.store %4, %2[%arg3] : memref<1024xi32>
        %5 = arith.shrsi %4, %c8_i32 : i32
        %6 = arith.cmpi eq, %5, %c1_i32 : i32
        scf.if %6 {
          %12 = affine.load %1[0] : memref<1xi32> // <----   crossed the block boundary
          %13 = arith.xori %12, %c283_i32 : i32
          affine.store %13, %2[%arg3] : memref<1024xi32>
        }
        %7 = affine.load %arg1[%arg3 + 1] : memref<?xi32>
        %8 = arith.shli %7, %c1_i32 : i32
        %9 = arith.xori %7, %8 : i32
        affine.store %9, %0[0] : memref<1xi32>
        %10 = arith.shrsi %9, %c8_i32 : i32
        %11 = arith.cmpi eq, %10, %c1_i32 : i32
        scf.if %11 {
          %12 = affine.load %0[0] : memref<1xi32> // <----   crossed the block boundary
          %13 = arith.xori %12, %c283_i32 : i32
          %14 = affine.load %2[%arg3] : memref<1024xi32>
          %15 = arith.xori %14, %13 : i32
          affine.store %15, %2[%arg3] : memref<1024xi32>
        } else {
          %12 = affine.load %0[0] : memref<1xi32> // <----   crossed the block boundary
          %13 = affine.load %2[%arg3] : memref<1024xi32>
          %14 = arith.xori %13, %12 : i32
          affine.store %14, %2[%arg3] : memref<1024xi32>
        }
      }
      affine.for %arg3 = 0 to 1024 {
        %3 = affine.load %2[%arg3] : memref<1024xi32>
        affine.store %3, %arg1[%arg3] : memref<?xi32>
      }
    }
    return
  }