iree-org / iree-llvm-sandbox

A sandbox for quick iteration and experimentation on projects related to IREE, MLIR, and LLVM
Apache License 2.0

[Indexing] For-loop sugar #741

Closed by makslevental 1 year ago

makslevental commented 1 year ago

This PR adds syntactic sugar for emitting `scf.for` loops from Python `for` loops; things like

    ten = Tensor.empty((7, 22, 330, 4400), f32)
    for i, _ in scf_range(0, 10, iter_args=[ten]):
      y = ten + ten
      scf_yield(y)

lower to IR matching these FileCheck patterns:

    module {
      %[[VAL_0:.*]] = tensor.empty() : tensor<7x22x330x4400xf32>
      %[[VAL_1:.*]] = arith.constant 0 : index
      %[[VAL_2:.*]] = arith.constant 10 : index
      %[[VAL_3:.*]] = arith.constant 1 : index
      %[[VAL_4:.*]] = scf.for %[[VAL_5:.*]] = %[[VAL_1]] to %[[VAL_2]] step %[[VAL_3]] iter_args(%[[VAL_6:.*]] = %[[VAL_0]]) -> (tensor<7x22x330x4400xf32>) {
        %[[VAL_7:.*]] = arith.addf %[[VAL_6]], %[[VAL_6]] : tensor<7x22x330x4400xf32>
        scf.yield %[[VAL_7]] : tensor<7x22x330x4400xf32>
      }
    }
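
A common way to implement this kind of sugar is a Python generator that creates the loop op, yields exactly once (so the Python body is traced a single time into the loop region), and finishes the region when the `for` statement resumes it. The toy below sketches that idea under stated assumptions; it is not the PR's implementation. `ToyBuilder`, `constant`, `addf`, and this `scf_range`/`scf_yield` are hypothetical stand-ins that record simplified textual IR (types omitted) rather than calling the MLIR Python bindings, and the body reads the region argument `x` explicitly.

```python
# Toy model of generator-based loop sugar. Every name here is hypothetical:
# it records simplified textual IR (types omitted) instead of using MLIR.


class ToyBuilder:
    """Stand-in for an MLIR insertion point: records ops as indented text."""

    def __init__(self):
        self.lines = []
        self.indent = 0
        self.counter = 0

    def fresh(self):
        # Hand out SSA-style names %0, %1, ... in creation order.
        name = f"%{self.counter}"
        self.counter += 1
        return name

    def emit(self, text):
        self.lines.append("  " * self.indent + text)


B = ToyBuilder()


def constant(value):
    name = B.fresh()
    B.emit(f"{name} = arith.constant {value} : index")
    return name


def addf(lhs, rhs):
    name = B.fresh()
    B.emit(f"{name} = arith.addf {lhs}, {rhs}")
    return name


def scf_yield(*values):
    # Region terminator: these values become the next iteration's
    # iter_args (and, after the last iteration, the loop's results).
    B.emit("scf.yield " + ", ".join(values))


def scf_range(start, stop, step=1, iter_args=()):
    """Open an scf.for region, yield once so the Python body is traced
    into it a single time, then close the region when resumed."""
    lb, ub, st = constant(start), constant(stop), constant(step)
    result = B.fresh()
    iv = B.fresh()
    region_args = [B.fresh() for _ in iter_args]
    inits = ", ".join(f"{r} = {a}" for r, a in zip(region_args, iter_args))
    B.emit(f"{result} = scf.for {iv} = {lb} to {ub} step {st} iter_args({inits}) {{")
    B.indent += 1
    carried = region_args[0] if len(region_args) == 1 else region_args
    yield iv, carried
    # Resumed after the Python loop body finished: close the region.
    B.indent -= 1
    B.emit("}")


# Usage mirroring the first example (the tensor is just a named value here).
ten = B.fresh()
B.emit(f"{ten} = tensor.empty()")
for i, x in scf_range(0, 10, iter_args=[ten]):
    y = addf(x, x)
    scf_yield(y)

print("\n".join(B.lines))
```

Since the generator yields once, the body runs once while the IR is being built; the ten iterations only happen when the emitted `scf.for` executes. The recorded values appear in the same order as `VAL_0` through `VAL_7` in the patterns above.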

and things like

    ten = Tensor.empty((7, 22, 330, 4400), f32)
    for i, result in scf_range(0, 10, iter_args=[ten]):
      y = ten + ten
      scf_yield(y)
    return result

lower to IR matching these FileCheck patterns:

    module {
      func.func @test_fold() -> tensor<7x22x330x4400xf32> {
        %[[VAL_0:.*]] = tensor.empty() : tensor<7x22x330x4400xf32>
        %[[VAL_1:.*]] = arith.constant 0 : index
        %[[VAL_2:.*]] = arith.constant 10 : index
        %[[VAL_3:.*]] = arith.constant 1 : index
        %[[VAL_4:.*]] = scf.for %[[VAL_5:.*]] = %[[VAL_1]] to %[[VAL_2]] step %[[VAL_3]] iter_args(%[[VAL_6:.*]] = %[[VAL_0]]) -> (tensor<7x22x330x4400xf32>) {
          %[[VAL_7:.*]] = arith.addf %[[VAL_6]], %[[VAL_6]] : tensor<7x22x330x4400xf32>
          scf.yield %[[VAL_7]] : tensor<7x22x330x4400xf32>
        }
        return %[[VAL_8:.*]] : tensor<7x22x330x4400xf32>
      }
    }
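
Both examples rely on the `iter_args`/`scf.yield` contract of `scf.for`. As a sanity check on what the emitted loop computes, here is a small plain-Python model of that contract, with a scalar standing in for the tensor. `run_scf_for` is a hypothetical helper for illustration, not part of the sandbox or the MLIR Python bindings.

```python
from typing import Callable, Sequence, Tuple

# Hypothetical body signature: (induction variable, region args) -> yielded values.
Body = Callable[[int, Tuple[float, ...]], Sequence[float]]


def run_scf_for(lb: int, ub: int, step: int,
                inits: Sequence[float], body: Body) -> Tuple[float, ...]:
    """Model of scf.for with iter_args: the region arguments start as
    `inits`, each iteration's scf.yield values become the next iteration's
    region arguments, and the op's results are the last yielded values
    (or `inits` unchanged if the loop runs zero times)."""
    if step <= 0:
        raise ValueError("scf.for requires a positive step")
    carried = tuple(inits)
    for iv in range(lb, ub, step):  # half-open range [lb, ub)
        yielded = tuple(body(iv, carried))
        if len(yielded) != len(carried):
            raise ValueError("scf.yield must forward one value per iter_arg")
        carried = yielded
    return carried


# Second example with a scalar standing in for the tensor: y = x + x, ten times.
(result,) = run_scf_for(0, 10, 1, [1.0], lambda iv, args: (args[0] + args[0],))
print(result)  # 1024.0
```

Under these semantics, a value used after the loop, such as the `result` returned in the second example, should correspond to the op's result, i.e. whatever the final iteration yielded.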