cornell-zhang / allo

Allo: A Programming Model for Composable Accelerator Design
https://cornell-zhang.github.io/allo
Apache License 2.0

[BUG] Excessive Copy Loops #121

Open matth2k opened 9 months ago

matth2k commented 9 months ago

Describe the bug
Excessive copy loops are created due to data type conversion of tensors expressed in the linalg dialect.

To Reproduce

import allo
from allo.ir.types import uint32

N = 20

def test_vadd():
    def kernel(A: uint32[N], B: uint32[N]) -> uint32[N]:
        return A + B

    s = allo.customize(kernel)
    print(s.module)

Buggy output

#map = affine_map<(d0) -> (d0)>
module {
  func.func @kernel(%arg0: memref<20xi32>, %arg1: memref<20xi32>) -> memref<20xi32> attributes {itypes = "uu", otypes = "u"} {
    %alloc = memref.alloc() {unsigned} : memref<20xi33>
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : memref<20xi32>) outs(%alloc : memref<20xi33>) {
    ^bb0(%in: i32, %out: i33):
      %0 = arith.extui %in : i32 to i33
      linalg.yield %0 : i33
    }
    %alloc_0 = memref.alloc() {unsigned} : memref<20xi33>
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg1 : memref<20xi32>) outs(%alloc_0 : memref<20xi33>) {
    ^bb0(%in: i32, %out: i33):
      %0 = arith.extui %in : i32 to i33
      linalg.yield %0 : i33
    }
    %alloc_1 = memref.alloc() : memref<20xi33>
    linalg.add {op_name = "add_0"} ins(%alloc, %alloc_0 : memref<20xi33>, memref<20xi33>) outs(%alloc_1 : memref<20xi33>)
    %alloc_2 = memref.alloc() {unsigned} : memref<20xi32>
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%alloc_1 : memref<20xi33>) outs(%alloc_2 : memref<20xi32>) {
    ^bb0(%in: i33, %out: i32):
      %0 = arith.trunci %in : i33 to i32
      linalg.yield %0 : i32
    }
    return %alloc_2 : memref<20xi32>
  }
}

In short, when this is lowered to affine, it manifests as excessive copying at the beginning of the program, and our AMC flow is very sensitive to this.

What really should happen is to notice that the type the addition result is bound to is the same as the input type, so the add can simply be (i32, i32) -> i32 with normal wraparound.
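To make the wraparound point concrete, here is a quick NumPy check (an illustration only, not Allo code): uint32 addition is already modulo 2**32, so no widened intermediate is needed when the result is declared as uint32 anyway.

import numpy as np

# Illustration only: uint32 array addition wraps modulo 2**32.
a = np.array([0xFFFFFFFF], dtype=np.uint32)
b = np.array([1], dtype=np.uint32)
print(a + b)  # [0]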

chhzh123 commented 9 months ago

Thanks for bringing this issue up. Allo has a strong type system, which is why it needs to guarantee that intermediate results will not overflow. I think we can either (1) check the input and output data types and bypass the type-extension rule if they align, or (2) fuse those linalg operations into one.
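To sketch option (1) — a hypothetical illustration of the idea only, not Allo's actual type-inference code — the builder would only widen when the declared result type cannot hold the operands:

# Hypothetical sketch of option (1); names are illustrative, not Allo internals.
def needs_widening(lhs_bits: int, rhs_bits: int, result_bits: int) -> bool:
    # Adding two k-bit values needs k+1 bits in general, but if the user
    # binds the result to a k-bit type anyway, widening and truncating
    # back is a no-op and the extra copy loops can be skipped.
    return result_bits > max(lhs_bits, rhs_bits)

print(needs_widening(32, 32, 33))  # True: keep the extension
print(needs_widening(32, 32, 32))  # False: emit an i32 add with wraparound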

As a workaround, you can explicitly traverse each element of the arrays with for loops so that no linalg operations are built; see the sketch below.
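For example, something along these lines (a minimal sketch of the workaround, reusing N, uint32, and allo from the repro above):

def kernel(A: uint32[N], B: uint32[N]) -> uint32[N]:
    C: uint32[N] = 0
    for i in range(N):
        # element-wise add in an explicit loop; no linalg ops are generated
        C[i] = A[i] + B[i]
    return C

s = allo.customize(kernel)
print(s.module)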

andrewb1999 commented 8 months ago

Your explanation makes sense @chhzh123, but I wonder what optimization HLS is doing to avoid this issue. Does it just fully unroll the copy loops so that the extension and truncation become free? If that's the case, maybe we can do a similar optimization in AMC to avoid this issue altogether.

chhzh123 commented 8 months ago

I think Vivado/Vitis HLS only unrolls loops with small loop bounds; otherwise, we need to write an explicit unroll pragma to inform HLS. However, unrolling may incur excessive resource usage. I still think the best approach is to fuse the loops into one.

zhangzhiru commented 8 months ago

I think the main hiccup is that we are lowering to linalg, which is less expressive than imperative programs. So we have to extend both input vectors to int33 first, then add them, and finally truncate the result back to int32. To clean things up, we really need an extra pass to remove the unnecessary extend and truncate. Another option is to insert a primitive that fuses the loops so another optimization pass at a lower level can finish the job, but that is not a good solution.
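To make the fusion idea concrete, here is a plain-Python illustration (not generated IR) of what a single fused loop would compute per element: the widening, addition, and truncation collapse into one loop body, with no intermediate i33 buffers.

MASK32 = (1 << 32) - 1  # truncation back to 32 bits

def fused_vadd(A, B):
    # One loop instead of two copy loops + one add loop + one truncate loop.
    C = [0] * len(A)
    for i in range(len(A)):
        C[i] = (A[i] + B[i]) & MASK32  # widened add, truncated in place
    return C

print(fused_vadd([0xFFFFFFFF, 1], [1, 2]))  # [0, 3]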

matth2k commented 7 months ago

I have the fix implemented within our AMC backend. But I will eventually come up with a more universal solution and submit it as a separate PR to Allo.