This PR adds syntactic sugar for `for` loops; things like
```python
ten = Tensor.empty((7, 22, 330, 4400), f32)
for i, _ in scf_range(0, 10, iter_args=[ten]):
    y = ten + ten
    scf_yield(y)
```
lower to
```mlir
module {
  %[[VAL_0:.*]] = tensor.empty() : tensor<7x22x330x4400xf32>
  %[[VAL_1:.*]] = arith.constant 0 : index
  %[[VAL_2:.*]] = arith.constant 10 : index
  %[[VAL_3:.*]] = arith.constant 1 : index
  %[[VAL_4:.*]] = scf.for %[[VAL_5:.*]] = %[[VAL_1]] to %[[VAL_2]] step %[[VAL_3]] iter_args(%[[VAL_6:.*]] = %[[VAL_0]]) -> (tensor<7x22x330x4400xf32>) {
    %[[VAL_7:.*]] = arith.addf %[[VAL_6]], %[[VAL_6]] : tensor<7x22x330x4400xf32>
    scf.yield %[[VAL_7]] : tensor<7x22x330x4400xf32>
  }
}
```
and things like
```python
ten = Tensor.empty((7, 22, 330, 4400), f32)
for i, result in scf_range(0, 10, iter_args=[ten]):
    y = ten + ten
    scf_yield(y)
return result
```
lower to

```mlir
module {
  func.func @test_fold() -> tensor<7x22x330x4400xf32> {
    %[[VAL_0:.*]] = tensor.empty() : tensor<7x22x330x4400xf32>
    %[[VAL_1:.*]] = arith.constant 0 : index
    %[[VAL_2:.*]] = arith.constant 10 : index
    %[[VAL_3:.*]] = arith.constant 1 : index
    %[[VAL_4:.*]] = scf.for %[[VAL_5:.*]] = %[[VAL_1]] to %[[VAL_2]] step %[[VAL_3]] iter_args(%[[VAL_6:.*]] = %[[VAL_0]]) -> (tensor<7x22x330x4400xf32>) {
      %[[VAL_7:.*]] = arith.addf %[[VAL_6]], %[[VAL_6]] : tensor<7x22x330x4400xf32>
      scf.yield %[[VAL_7]] : tensor<7x22x330x4400xf32>
    }
    return %[[VAL_8:.*]] : tensor<7x22x330x4400xf32>
  }
}
```