iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Missing propagation for `unpack -> collapse_shape` to `collapse_shape -> unpack`. #17465

Open hanhanW opened 3 months ago

hanhanW commented 3 months ago

This blocks fusion of `unpack` with its consumers. E.g., we should be able to swap the `unpack` and `collapse_shape` below, because the `collapse_shape` only folds a unit dimension away.

  func.func @foo(%arg0: tensor<1x1024x1024x16x16xf32>) -> tensor<16384x16384xf32> {
    %0 = tensor.empty() : tensor<1x16384x16384xf32>
    %unpack = tensor.unpack %arg0 outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %0 : tensor<1x1024x1024x16x16xf32> -> tensor<1x16384x16384xf32>
    %collapsed = tensor.collapse_shape %unpack [[0, 1], [2]] : tensor<1x16384x16384xf32> into tensor<16384x16384xf32>
    %1 = tensor.empty() : tensor<16384x16384xf32>
    %2 = linalg.softmax dimension(1) ins(%collapsed : tensor<16384x16384xf32>) outs(%1 : tensor<16384x16384xf32>) -> tensor<16384x16384xf32>
    return %2 : tensor<16384x16384xf32>
  }
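For illustration, here is a hand-written sketch (not compiler output) of what the swapped form could look like: the `collapse_shape` folds the leading unit dimension of the packed tensor first, and the `unpack` then produces the rank-2 result directly, with `inner_dims_pos` shifted down by one to account for the folded dimension. The function name `@foo_swapped` is made up for the example.

  func.func @foo_swapped(%arg0: tensor<1x1024x1024x16x16xf32>) -> tensor<16384x16384xf32> {
    // Fold the leading unit dim of the packed tensor first...
    %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3], [4]] : tensor<1x1024x1024x16x16xf32> into tensor<1024x1024x16x16xf32>
    %0 = tensor.empty() : tensor<16384x16384xf32>
    // ...then unpack straight into the rank-2 shape, so the softmax consumer can fuse with the unpack.
    %unpack = tensor.unpack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<1024x1024x16x16xf32> -> tensor<16384x16384xf32>
    %1 = tensor.empty() : tensor<16384x16384xf32>
    %2 = linalg.softmax dimension(1) ins(%unpack : tensor<16384x16384xf32>) outs(%1 : tensor<16384x16384xf32>) -> tensor<16384x16384xf32>
    return %2 : tensor<16384x16384xf32>
  }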
hanhanW commented 3 months ago

It should be done in https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp

MaheshRavishankar commented 3 months ago

Where is this collapse shape coming from? There might be a uniform way of handling this in the reshape propagation passes that run later.

hanhanW commented 3 months ago

I don't know; it is here after set encoding. A sequence of linalg ops is raised to a softmax op in the GlobalOptimization stage. Can we push reshape ops down through named ops? That does not look easy to me, so I think we can implement an (unpack, collapse_shape) propagation pattern for this case.

MaheshRavishankar commented 3 months ago

> I don't know; it is here after set encoding. A sequence of linalg ops is raised to a softmax op in the GlobalOptimization stage. Can we push reshape ops down through named ops? That does not look easy to me, so I think we can implement an (unpack, collapse_shape) propagation pattern for this case.

The propagation patterns are implemented for Linalg ops, but we can add propagation patterns for other ops as well. I'd like to consolidate all the propagation patterns in one place if possible. We can still add those patterns, but we should be able to use them in the reshape propagation passes.