iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Data Tiling: Folding tensor.pack(tensor.unpack) to simplified form #13532

Closed: hanhanW closed this issue 1 year ago

hanhanW commented 1 year ago

If the two ops have the same attributes and there are no padding values, we can simply fold them away: the pair is replaced by the source of the unpack op.
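
The fold can be checked at the element level with a minimal sketch. It assumes a tensor is modeled as a Python dict from index tuple to value, and that only `inner_dims_pos = [0, 1]` is used, with no `outer_dims_perm` and no padding. `iota`, `unpack_2d`, and `pack_2d` are hypothetical helpers for illustration, not MLIR or IREE APIs.

```python
from itertools import product


def iota(shape):
    """Hypothetical helper: a tensor as {index tuple: value} holding distinct values."""
    return {idx: n for n, idx in enumerate(product(*(range(d) for d in shape)))}


def unpack_2d(t, tiles):
    """Models tensor.unpack with inner_dims_pos = [0, 1]:
    element (o0, o1, i0, i1) moves to (o0 * t0 + i0, o1 * t1 + i1)."""
    t0, t1 = tiles
    return {(o0 * t0 + i0, o1 * t1 + i1): v for (o0, o1, i0, i1), v in t.items()}


def pack_2d(t, tiles):
    """Models tensor.pack with inner_dims_pos = [0, 1] and no padding value:
    element (r, c) moves to (r // t0, c // t1, r % t0, c % t1)."""
    t0, t1 = tiles
    return {(r // t0, c // t1, r % t0, c % t1): v for (r, c), v in t.items()}


packed = iota((24, 8, 16, 16))            # same shape as %1462 in the example
unpacked = unpack_2d(packed, (16, 16))    # 384x128
repacked = pack_2d(unpacked, (16, 16))    # same attributes as the unpack
print(repacked == packed)  # True: pack(unpack(x)) is just x
```

The no-padding condition matters: with padding, the unpack drops the elements in the padded region and the pack refills them with the padding value, so the round trip is not guaranteed to reproduce the source.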

If they have different attributes and there are no padding values, we should be able to decompose the ops into transpose + reshapes + transpose ops. The transpose ops can be fused with their producers and consumers, and the reshape ops are just metadata ops. E.g.,

%unpack_3762 = tensor.unpack %1462
  inner_dims_pos = [0, 1]
  inner_tiles = [16, 16]
  into %30 : tensor<24x8x16x16xf32> -> tensor<384x128xf32>
%pack_3763 = tensor.pack %unpack_3762
  inner_dims_pos = [0, 1]
  inner_tiles = [16, 1]
  into %37 : tensor<384x128xf32> -> tensor<24x128x16x1xf32>
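
The shapes above follow from the tile sizes: each tiled dimension is divided by its tile size and the tiles are appended as the innermost dimensions. Below is a small sketch of that arithmetic, assuming no `outer_dims_perm` and no padding; `packed_shape` is a hypothetical helper, not an MLIR or IREE API.

```python
def packed_shape(shape, inner_dims_pos, inner_tiles):
    """Result shape of a tensor.pack without outer_dims_perm.

    Assumes no padding: every tiled dimension must be divisible by its tile size.
    """
    outer = list(shape)
    for pos, tile in zip(inner_dims_pos, inner_tiles):
        if shape[pos] % tile != 0:
            raise ValueError(f"dim {pos} (size {shape[pos]}) would need padding for tile {tile}")
        outer[pos] = shape[pos] // tile
    # Outer (tile-count) dims come first, followed by one dim per entry of inner_tiles.
    return tuple(outer) + tuple(inner_tiles)


unpacked = (384, 128)  # tensor<384x128xf32>
print(packed_shape(unpacked, [0, 1], [16, 16]))  # (24, 8, 16, 16): source of the unpack
print(packed_shape(unpacked, [0, 1], [16, 1]))   # (24, 128, 16, 1): result of the pack
```

Because the two packed layouts differ (16x16 tiles versus 16x1 tiles), this pair cannot simply fold to the unpack's source.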

After decomposition:

%transposed = linalg.transpose
  ins(%arg0 : tensor<24x8x16x16xf32>)
  outs(%1 : tensor<24x16x8x16xf32>)
  permutation = [0, 2, 1, 3]
%collapsed = tensor.collapse_shape %transposed [[0, 1], [2, 3]]
  : tensor<24x16x8x16xf32> into tensor<384x128xf32>
%expanded = tensor.expand_shape %collapsed [[0, 1], [2, 3]]
  : tensor<384x128xf32> into tensor<24x16x128x1xf32>
%transposed_1 = linalg.transpose
  ins(%expanded : tensor<24x16x128x1xf32>)
  outs(%2 : tensor<24x128x16x1xf32>)
  permutation = [0, 2, 1, 3]
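
The decomposition can be sanity-checked by applying both forms to the same data. This sketch models tensors as Python dicts from index tuple to value and assumes `inner_dims_pos = [0, 1]`, no `outer_dims_perm`, and no padding. All helper names (`iota`, `transpose`, `collapse_2x2`, `expand_2x2`, `unpack_2d`, `pack_2d`) are hypothetical and only mimic the semantics of the corresponding ops.

```python
from itertools import product


def iota(shape):
    """A tensor (dict: index tuple -> value) filled with distinct values 0, 1, 2, ..."""
    return {idx: n for n, idx in enumerate(product(*(range(d) for d in shape)))}


def transpose(t, perm):
    # linalg.transpose semantics: result dim i is input dim perm[i].
    return {tuple(idx[p] for p in perm): v for idx, v in t.items()}


def collapse_2x2(t, shape):
    # tensor.collapse_shape [[0, 1], [2, 3]] (row-major) of a tensor with 4-D `shape`.
    _, b, _, d = shape
    return {(a0 * b + a1, a2 * d + a3): v for (a0, a1, a2, a3), v in t.items()}


def expand_2x2(t, shape):
    # tensor.expand_shape [[0, 1], [2, 3]] of a 2-D tensor into 4-D `shape`.
    _, b, _, d = shape
    return {(r // b, r % b, c // d, c % d): v for (r, c), v in t.items()}


def unpack_2d(t, tiles):
    # tensor.unpack with inner_dims_pos = [0, 1].
    t0, t1 = tiles
    return {(o0 * t0 + i0, o1 * t1 + i1): v for (o0, o1, i0, i1), v in t.items()}


def pack_2d(t, tiles):
    # tensor.pack with inner_dims_pos = [0, 1] and no padding.
    t0, t1 = tiles
    return {(r // t0, c // t1, r % t0, c % t1): v for (r, c), v in t.items()}


src = iota((24, 8, 16, 16))

# Reference: unpack with 16x16 tiles, then pack with 16x1 tiles.
reference = pack_2d(unpack_2d(src, (16, 16)), (16, 1))

# Decomposed form: transpose -> collapse_shape -> expand_shape -> transpose.
transposed = transpose(src, [0, 2, 1, 3])              # 24x16x8x16
collapsed = collapse_2x2(transposed, (24, 16, 8, 16))  # 384x128
expanded = expand_2x2(collapsed, (24, 16, 128, 1))     # 24x16x128x1
transposed_1 = transpose(expanded, [0, 2, 1, 3])       # 24x128x16x1

print(transposed_1 == reference)  # True
```

The check relies on the no-padding assumption; with padding, the unpack would also drop elements and the pack would insert padding values, which this chain of transposes and reshapes does not model.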

One question is whether we want to do this before forming dispatches. In that case the reshape ops become fusion barriers, so only the transposes can be fused with their producers and consumers. Without decomposition, we can form the unpack + pack pair into a single dispatch, then tile and distribute the work. IMO, decomposing is a bit worse because it needs an extra kernel launch (if there are producers and consumers).

hanhanW commented 1 year ago

The simplification will result in more dispatch launches because the reshape ops become fusion barriers. We should instead fuse an unfoldable unpack + pack pair into a single dispatch and handle it in codegen.