Better representation of generic reshapes when possible #18058

Open MaheshRavishankar opened 1 month ago

MaheshRavishankar commented 1 month ago

Some of the models (like SDXL) coming from torch seem to contain this type of pattern:

%collapsed = tensor.collapse_shape %input [[0], [1, 2], [3]] : tensor<2x32x10x16384xf32> into tensor<2x320x16384xf32>
%expanded = tensor.expand_shape %collapsed [[0], [1], [2, 3]] output_shape [2, 320, 128, 128] : tensor<2x320x16384xf32> into tensor<2x320x128x128xf32>

What this is trying to do is go from a shape of tensor<2x32x10x16384xf32> to a shape of tensor<2x320x128x128xf32>. The fusion heuristics, which essentially rely on propagating expand_shapes up and collapse_shapes down, get stuck here, and we miss out on some fusion opportunities. A more canonical representation of this is

%expanded = tensor.expand_shape %input [[0], [1], [2], [3, 4]] output_shape [2, 32, 10, 128, 128] : tensor<2x32x10x16384xf32> into tensor<2x32x10x128x128xf32> 
%collapsed = tensor.collapse_shape %expanded [[0], [1, 2], [3], [4]] : tensor<2x32x10x128x128xf32> into tensor<2x320x128x128xf32>

Then the propagation should be able to fuse better.
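As a quick sanity check (not from the issue, just an illustration): both orderings are plain row-major reshapes of the same contiguous data, so they compute identical results. A minimal NumPy sketch with dummy data and hypothetical variable names:

import numpy as np

# Dummy data with the shape from the issue.
x = np.arange(2 * 32 * 10 * 16384, dtype=np.float32).reshape(2, 32, 10, 16384)

# Original ordering: collapse_shape [[0], [1, 2], [3]] then expand_shape [[0], [1], [2, 3]]:
# tensor<2x32x10x16384> -> tensor<2x320x16384> -> tensor<2x320x128x128>
collapse_then_expand = x.reshape(2, 320, 16384).reshape(2, 320, 128, 128)

# Proposed ordering: expand_shape [[0], [1], [2], [3, 4]] then collapse_shape [[0], [1, 2], [3], [4]]:
# tensor<2x32x10x16384> -> tensor<2x32x10x128x128> -> tensor<2x320x128x128>
expand_then_collapse = x.reshape(2, 32, 10, 128, 128).reshape(2, 320, 128, 128)

# Row-major reshapes of contiguous data, so the two orderings agree elementwise.
assert np.array_equal(collapse_then_expand, expand_then_collapse)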

MaheshRavishankar commented 1 month ago

@Max191 is this something that you already have a pattern to fix? I think I remember you saying something like this but didn't connect it in my head.

Max191 commented 1 month ago

Yes, this PR fixes it: https://github.com/llvm/llvm-project/pull/94637

> A more canonical representation of this is

I don't know if it is more canonical to have one ordering over the other, but the reshape propagation patterns should not be blocked by cases like this. Ideally, we should also have patterns for the inverse case of expand_shape->collapse_shape.
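For concreteness, a minimal NumPy sketch of what such an inverse case could look like (the shapes here are hypothetical, not taken from any model): an expand_shape followed by a collapse_shape that could equivalently be rewritten as a collapse followed by an expand.

import numpy as np

# Hypothetical shapes for illustration only.
y = np.arange(2 * 320 * 16384, dtype=np.float32).reshape(2, 320, 16384)

# expand_shape [[0], [1], [2, 3]] followed by collapse_shape [[0, 1], [2], [3]]:
# tensor<2x320x16384> -> tensor<2x320x128x128> -> tensor<640x128x128>
expand_then_collapse = y.reshape(2, 320, 128, 128).reshape(640, 128, 128)

# The same overall reshape written as collapse_shape [[0, 1], [2]] followed by
# expand_shape [[0], [1, 2]]:
# tensor<2x320x16384> -> tensor<640x16384> -> tensor<640x128x128>
collapse_then_expand = y.reshape(640, 16384).reshape(640, 128, 128)

assert np.array_equal(expand_then_collapse, collapse_then_expand)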

I was supposed to refactor this into the Tensor dialect, but the PR fell low on my priority list, so I haven't touched it in a while.