tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

Support lowering add and mul through ttir.generic metal backend #468

Closed. nsmithtt closed this 2 weeks ago.

nsmithtt commented 3 weeks ago

The core of this change is lowering arith ops on tensors into a loop nest. Consider the following ttir.generic body:

  ^bb0(%arg2: tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>, %arg3, %arg4):
    %8 = arith.addf %arg2, %arg3 : tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>
    "ttir.yield"(%8) : (tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>) -> ()
  })
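In the body above, the 64x128 f32 tensor is laid out as memref<2x4x!tt.tile<32x32, f32>>, i.e. a 2x4 grid of 32x32 tiles (64/32 = 2, 128/32 = 4). The pure-Python sketch below illustrates that view and checks that adding tile by tile gives the same result as a plain elementwise add. The helper names (to_tiles, add_tile, tiled_add) are hypothetical and are not part of tt-mlir.

```python
# Hypothetical sketch: view a row-major 2D list as a grid of TILE x TILE blocks,
# mirroring memref<2x4x!tt.tile<32x32, f32>> for a 64x128 tensor.

TILE = 32

def to_tiles(mat, tile=TILE):
    """Split a row-major 2D list into a grid (list of lists) of tile x tile blocks."""
    rows, cols = len(mat), len(mat[0])
    assert rows % tile == 0 and cols % tile == 0, "shape must be tile-aligned"
    grid = []
    for ti in range(rows // tile):
        grid_row = []
        for tj in range(cols // tile):
            # Rows ti*tile .. (ti+1)*tile-1, columns tj*tile .. (tj+1)*tile-1.
            block = [row[tj * tile:(tj + 1) * tile]
                     for row in mat[ti * tile:(ti + 1) * tile]]
            grid_row.append(block)
        grid.append(grid_row)
    return grid

def add_tile(a, b):
    """Elementwise add of two equally shaped tiles (the arith.addf on one tile)."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def tiled_add(ga, gb):
    """Apply add_tile to every pair of corresponding tiles in two tile grids."""
    return [[add_tile(ta, tb) for ta, tb in zip(ra, rb)]
            for ra, rb in zip(ga, gb)]
```

Because the add is elementwise, tiling does not change the result; it only changes the unit of work, which is what lets the lowering issue one add_tiles call per 32x32 tile.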

It is lowered into the following loop nest using the scf dialect:

  "ttkernel.binary_op_init_common"(%arg2, %arg3, %arg4)
  "ttkernel.add_tiles_init"(%arg2, %arg3)
  %8 = scf.for %arg5 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg6 = %c0_i32) -> (i32)  : i32 {
    %9 = scf.for %arg7 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg8 = %arg6) -> (i32)  : i32 {
      "ttkernel.tile_regs_acquire"() : () -> ()
      "ttkernel.add_tiles"(%arg2, %arg3, %arg8, %arg8, %c0_i32)
      "ttkernel.tile_regs_commit"() : () -> ()
      "ttkernel.tile_regs_wait"() : () -> ()
      "ttkernel.pack_tile"(%c0_i32, %arg4, %arg8)
      "ttkernel.tile_regs_release"() : () -> ()
      %10 = arith.addi %arg8, %c1_i32 : i32
      scf.yield %10 : i32
    }
    scf.yield %9 : i32
  }
  "ttkernel.return"() : () -> ()
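In the nest above, the two scf.for loops walk the 2x4 tile grid, and the i32 iter_args thread a running linear tile index (%arg6 into %arg8, incremented by arith.addi) through both loops, so each iteration reads input tile %arg8 and packs output tile %arg8. The Python sketch below replays that control flow under simplifying assumptions: circular buffers are plain lists, scalars stand in for whole tiles, and DST is a one-slot list. The function name lower_add_loop_nest is hypothetical. It checks that the carried index equals the row-major index i * 4 + j.

```python
# Hypothetical replay of the scf loop nest. Scalars stand in for tiles,
# lists stand in for circular buffers, and dst is a single DST slot (index 0).

def lower_add_loop_nest(in0, in1, out, grid_rows=2, grid_cols=4):
    """Run the nest; return (i, j, tile_index) for every iteration, in order."""
    dst = [None]                      # DST register file with one slot, %c0_i32
    visited = []
    idx = 0                           # outer iter_args init: %arg6 = %c0_i32
    for i in range(grid_rows):        # scf.for %arg5 = 0 to 2
        inner = idx                   # inner iter_args init: %arg8 = %arg6
        for j in range(grid_cols):    # scf.for %arg7 = 0 to 4
            # tile_regs_acquire: the math side takes DST.
            dst[0] = in0[inner] + in1[inner]   # add_tiles(in0, in1, idx, idx, 0)
            # tile_regs_commit / tile_regs_wait: DST is handed to the packer.
            out[inner] = dst[0]                # pack_tile(0, out, idx)
            # tile_regs_release: DST is free for the next iteration.
            visited.append((i, j, inner))
            inner += 1                         # arith.addi %arg8, %c1_i32
        idx = inner                            # outer scf.yield of the inner result
    return visited
```

For example, with in0 = [0, 1, ..., 7], in1 = [10] * 8 and out = [None] * 8, out ends up as [10, 11, ..., 17]. Carrying the index through iter_args avoids recomputing i * 4 + j in each iteration; the two forms are equivalent for a row-major walk.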
nsmithtt commented 2 weeks ago

If someone has a chance to review this PR, that'd be great!

rpavlovicTT commented 2 weeks ago

Nice work! Looks good to me; just minor comments.