iree-org / iree-llvm-sandbox

A sandbox for quick iteration and experimentation on projects related to IREE, MLIR, and LLVM
Apache License 2.0

[Indexing] Arbitrary slicing #734

Closed: makslevental closed this pull request 1 year ago.

makslevental commented 1 year ago

Depends on https://github.com/iree-org/iree-llvm-sandbox/pull/740.

This PR implements the remainder of slicing, i.e., slicing with arbitrary start, stop, and step:

    ten = Tensor.empty((7, 22, 330, 4400), f32)
    # CHECK: Tensor(%[[TEN:.*]], tensor<7x22x330x4400xf32>)
    print(ten)

    w = ten[:, ::2]
    # CHECK: %[[ARA:.*]] = arith.constant dense<[[0], [2], [4], [6], [8], [10], [12], [14], [16], [18], [20]]> : tensor<11x1xindex>
    print(w.owner.operands[1].owner)
    # CHECK: %{{.*}} = indexing.gather %[[TEN]][%[[ARA]]] gather_dims([1]) unique : (tensor<7x22x330x4400xf32>, tensor<11x1xindex>) -> tensor<11x7x330x4400xf32>
    print(w.owner)

    w = ten[:, ::2, ::30]
    # CHECK: %[[ARA:.*]] = arith.constant dense<[[0, 0], [2, 30], [4, 60], [6, 90], [8, 120], [10, 150], [12, 180], [14, 210], [16, 240], [18, 270], [20, 300]]> : tensor<11x2xindex>
    print(w.owner.operands[1].owner)
    # CHECK: %{{.*}} = indexing.gather %[[TEN]][%[[ARA]]] gather_dims([1, 2]) unique : (tensor<7x22x330x4400xf32>, tensor<11x2xindex>) -> tensor<11x7x4400xf32>
    print(w.owner)

    w = ten[:, ::2, ::30, ::400]
    # CHECK: %[[ARA:.*]] = arith.constant dense<[[0, 0, 0], ..., [20, 300, 4000]]> : tensor<11x3xindex>
    print(w.owner.operands[1].owner)
    # CHECK: %{{.*}} = indexing.gather %[[TEN]][%[[ARA]]] gather_dims([1, 2, 3]) unique : (tensor<7x22x330x4400xf32>, tensor<11x3xindex>) -> tensor<11x7xf32>
    print(w.owner)

    w = ten[:, :, 100:200:5, 1000:2000:50]
    # CHECK: %[[ARA:.*]] = arith.constant dense<[[100, 1000], [105, 1050], ..., [195, 1950]]> : tensor<20x2xindex>
    print(w.owner.operands[1].owner)
    # CHECK: %{{.*}} = indexing.gather %[[TEN]][%[[ARA]]] gather_dims([2, 3]) unique : (tensor<7x22x330x4400xf32>, tensor<20x2xindex>) -> tensor<20x7x22xf32>
    print(w.owner)
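
To make the shapes in the CHECK lines above concrete, here is a minimal pure-Python sketch of how a tuple of slices could be turned into a coordinate tensor, a `gather_dims` list, and a result shape. It assumes what the examples show: a bare `:` keeps a dimension, every sliced dimension yields the same number of indices, and those indices are zipped row-wise (not combined as an outer product). Only slice objects are handled, not integer indices. The helpers `slices_to_gather` and `_Idx` are hypothetical and not part of the sandbox's API.

```python
# Hypothetical sketch: models how a tuple of Python slices could lower to the
# (coordinates, gather_dims, result shape) seen in the CHECK lines.
# Not the sandbox's actual implementation.


class _Idx:
    """Lets us write idx[:, ::2] to obtain a tuple of slice objects."""

    def __getitem__(self, key):
        return key if isinstance(key, tuple) else (key,)


idx = _Idx()


def slices_to_gather(shape, index):
    """Return (coords, gather_dims, result_shape) for a tuple of slices."""
    # Pad with full slices so every dimension has an entry.
    index = tuple(index) + (slice(None),) * (len(shape) - len(index))
    gather_dims, columns = [], []
    for dim, (size, s) in enumerate(zip(shape, index)):
        if s == slice(None):
            continue  # a bare ':' keeps the dimension; it is not gathered
        start, stop, step = s.indices(size)  # normalize against the dim size
        gather_dims.append(dim)
        columns.append(list(range(start, stop, step)))
    if not columns:
        raise ValueError("no dimension is sliced")
    # Assumption (matching the examples): every sliced dim yields the same
    # number of indices, and they are zipped row-wise.
    lengths = {len(c) for c in columns}
    if len(lengths) != 1:
        raise ValueError(f"sliced dims have different lengths: {sorted(lengths)}")
    coords = [list(row) for row in zip(*columns)]
    kept = [size for d, size in enumerate(shape) if d not in gather_dims]
    # Result: the gathered batch dim first, then the untouched dims in order.
    return coords, gather_dims, [len(coords)] + kept


coords, dims, result_shape = slices_to_gather((7, 22, 330, 4400), idx[:, ::2, ::30])
print(dims, result_shape)  # [1, 2] [11, 7, 4400]
print(coords[:3])          # [[0, 0], [2, 30], [4, 60]]
```

Under these assumptions the sketch reproduces the `tensor<11x2xindex>` coordinates and the `tensor<11x7x4400xf32>` result type of the `ten[:, ::2, ::30]` case.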

It also supports non-constant operands in the slice object:

    ten = Tensor.empty((7, 22, 330, 4400), f32)
    one = Scalar(1, dtype=index, fold=False)
    start = 100 * one
    stop = 200 * one
    step = 5 * one

    w = ten[:, :, start:stop:step]
    # CHECK: %[[VAL_8:.*]] = "indexing.arange"(%[[VAL_3]], %[[VAL_5]], %[[VAL_7]]) {operand_segment_sizes = array<i32: 1, 1, 1>} : (index, index, index) -> tensor<?x1xindex>
    print(w.owner.operands[1].owner)
    # CHECK: %[[VAL_9:.*]] = "indexing.gather"(%[[VAL_0]], %[[VAL_8]]) {gather_dims = array<i64: 2>, unique} : (tensor<7x22x330x4400xf32>, tensor<?x1xindex>) -> tensor<?x7x22x4400xf32>
    print(w.owner)

    w = ten[:, :, start:stop:5]
    # CHECK: %[[VAL_10:.*]] = "arith.constant"() <{value = 5 : index}> : () -> index
    print(w.owner.operands[1].owner.operands[2])
    # CHECK: %[[VAL_11:.*]] = "indexing.arange"(%[[VAL_3]], %[[VAL_5]], %[[VAL_10]]) {operand_segment_sizes = array<i32: 1, 1, 1>} : (index, index, index) -> tensor<?x1xindex>
    print(w.owner.operands[1].owner)
    # CHECK: %[[VAL_12:.*]] = "indexing.gather"(%[[VAL_0]], %[[VAL_11]]) {gather_dims = array<i64: 2>, unique} : (tensor<7x22x330x4400xf32>, tensor<?x1xindex>) -> tensor<?x7x22x4400xf32>
    print(w.owner)
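
With SSA values for start, stop, and step, the length of the `arange` is only known at runtime, which is why its type is `tensor<?x1xindex>` and the gathered result has a leading `?`. The sketch below shows the arithmetic for that runtime extent, assuming a positive step as in the examples; `arange_length` and `arange_column` are hypothetical helpers operating on plain Python integers, not sandbox APIs.

```python
# Hypothetical sketch of the runtime extent behind the '?' in
# tensor<?x1xindex>: with SSA start/stop/step, the arange length is only
# known when those values are. Assumes a positive step, as in the examples.


def arange_length(start: int, stop: int, step: int) -> int:
    """Number of elements in range(start, stop, step) for step > 0."""
    if step <= 0:
        raise ValueError("this sketch only handles positive steps")
    # Ceiling division of (stop - start) by step, clamped at zero.
    return max(0, (stop - start + step - 1) // step)


def arange_column(start: int, stop: int, step: int) -> list[list[int]]:
    """Materialize the n x 1 index tensor as nested lists."""
    return [[i] for i in range(start, stop, step)]


# Mirrors start = 100 * one, stop = 200 * one, step = 5 * one above.
one = 1
start, stop, step = 100 * one, 200 * one, 5 * one
n = arange_length(start, stop, step)
print(n)  # 20
assert len(arange_column(start, stop, step)) == n
```

For these concrete values the dynamic dimension would resolve to 20, matching the static `tensor<20x2xindex>` case for `100:200:5` in the earlier block.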

nicolasvasilache commented 1 year ago

To be consistent with tensor.pad, could we invert the attribute and call it "nofold"?

Also, for consistency, the default behavior of tensor.pad is to fold; i.e., nofold is an optional UnitAttr and its presence disables folding.

makslevental commented 1 year ago

> To be consistent with tensor.pad, could we invert the attribute and call it "nofold"?
>
> Also, for consistency, the default behavior of tensor.pad is to fold; i.e. this is an optional UnitAttr and its presence disables folding.

Are we sure that we want this? I went back and made it not fold by default because folding the arange produces an nx1xindex tensor of indices, which loses the structure of the arange (start, stop, step).
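
A small Python model of the concern raised here: once an arange is folded into a constant `nx1` index tensor, getting start, stop, and step back means rescanning the values, and sometimes it is not possible at all. `fold_arange` and `recover_arange` are hypothetical helpers working on plain lists rather than MLIR attributes; they illustrate the information loss, not the sandbox's folder.

```python
# Hypothetical model of folding an arange into a constant n x 1 index tensor,
# and of what it takes to get (start, stop, step) back afterwards.


def fold_arange(start: int, stop: int, step: int) -> list[list[int]]:
    """'Fold' arange(start, stop, step) into a dense n x 1 list of indices."""
    return [[i] for i in range(start, stop, step)]


def recover_arange(rows):
    """Try to re-derive (start, stop, step) from a folded n x 1 constant.

    Returns None when the values are not an evenly spaced increasing sequence
    or when there are too few of them to determine a step. The recovered stop
    is canonical (last + step), not necessarily the original one.
    """
    values = [r[0] for r in rows]
    if len(values) < 2:
        return None  # a single index (or none) does not determine the step
    step = values[1] - values[0]
    if step <= 0 or any(b - a != step for a, b in zip(values, values[1:])):
        return None
    return values[0], values[-1] + step, step


print(recover_arange(fold_arange(100, 200, 5)))  # (100, 200, 5)
print(recover_arange(fold_arange(100, 203, 5)))  # (100, 205, 5): original stop lost
print(recover_arange(fold_arange(0, 3, 5)))      # None: step is unrecoverable
```

Keeping the op unfolded preserves start, stop, and step directly, so downstream patterns do not have to rediscover them from a constant, and in cases like the last two above they could not recover the original values anyway.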