Closed hanhanW closed 4 months ago
It's hard to see at this level. The elementwise op fusion is supposed to do the `reshape -> elementwise op` to `elementwise op -> reshape` transformation. We need to look at the whole model, though, to see why this didn't happen. The IR before elementwise fusion would be the place to start. @hanhanW do you have that handy?
Yes, I have the IR before and after elementwise fusion: https://gist.githubusercontent.com/hanhanW/0df90c3751be3df5ce59515c36d3ad79/raw
Looking into it
The folding didn't happen because the indexing map for the reshaped input has a transpose. The folding has logic that checks whether the dim sequence is preserved, which seems to be too conservative.
Ok, thanks for looking into it. I don't think the requirement is conservative. Here is the IR, I think from before dispatch region formation:
```mlir
%11 = tensor.expand_shape [[0], [1, 2]] : tensor<128x384xf32> -> tensor<128x12x32xf32>
%12 = linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
      affine_map<(d0, d1, d2) -> (d0, d2)>,
      affine_map<(d0, d1, d2) -> (d0, d1, d2)>] ...} ins(%11 : tensor<128x12x32xf32>) ....
```
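To make the indexing maps above concrete, here is one plausible numpy reading of that generic (the op body is elided in the IR, so the addition and the bias shape are assumptions for illustration only):

```python
import numpy as np

# Assumed reading of the generic: out[d0, d1, d2] = f(in0[d1, d0, d2], bias[d0, d2]).
# in0 has shape 128x12x32 and is indexed by (d1, d0, d2), so d1 spans 128,
# d0 spans 12, and d2 spans 32.
in0 = np.arange(128 * 12 * 32, dtype=np.float32).reshape(128, 12, 32)
bias = np.ones((12, 32), dtype=np.float32)

# Map (d1, d0, d2) on in0 is a transpose of the two leading dims;
# map (d0, d2) on bias broadcasts it along d1. "+" stands in for the elided body.
out = in0.transpose(1, 0, 2) + bias[:, None, :]
assert out.shape == (12, 128, 32)
```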
One of the things that needs to be done carefully when fusing the reshape with its producer (by collapsing the dimensions of the consumer op) is keeping the indexing maps of the collapsed op as permutations as well. In this case, based on the indexing map for the operand that is to be fused, the dimensions `d0` and `d2` of the consumer `linalg.generic` are to be collapsed. But the output is computed using the indexing map `affine_map<(d0, d1, d2) -> (d0, d1, d2)>`, which doesn't have `d0` and `d2` as "consecutive" dimensions. In such cases the folded op would have to use `mod`s and `div`s to recreate the values of `d0` and `d2`, which interferes with all transformations later on. So it is a deliberate choice not to fold the reshape with the generic op.
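A small sketch of why collapsing non-consecutive dimensions forces `div`/`mod` arithmetic (sizes taken from the IR above; the loop is purely illustrative):

```python
# Sizes from the IR above: d0 spans 12 in the consumer, but for the fused
# operand the collapsed pair is (d0, d2) of the producer-side 128x384 view;
# here we just demonstrate the index arithmetic with two example sizes.
D0, D2 = 128, 32

# Suppose the fused op iterates a single collapsed index c in [0, D0*D2).
# Because the two original dims are not adjacent in the output map
# (d0, d1, d2), the output access must recover each one separately:
for c in (0, 31, 32, D0 * D2 - 1):
    d0 = c // D2  # floordiv
    d2 = c % D2   # mod
    assert 0 <= d0 < D0 and 0 <= d2 < D2
    # The output indexing expression would then contain floordiv/mod,
    # i.e. it is no longer a permutation -- which is what the fusion
    # logic deliberately avoids.
```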
Is this a pessimization? I think not. It might look as though we haven't fused the generic op with the matmul, but if you look at the original code the ops have been fused with other uses, so in the end the number of dispatches created is the same. If I had to guess, the initial sequence of operations would have been:
```mlir
%10 = linalg.matmul
%11 = linalg.generic .. ins(%10, .. : tensor<128x384xf32>, ...) // 2D op for bias-add
%12 = tensor.expand_shape %11 [[0], [1, 2]] : tensor<128x384xf32> -> tensor<128x12x32xf32>
%13 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>] ...} ... // Transpose operation
```
If you didn't propagate the `tensor.expand_shape` above the first `linalg.generic`, you would have fused that op with the `linalg.matmul`, but the last `linalg.generic` would be in its own dispatch. Now the first `linalg.generic` is fused with the second `linalg.generic`, so it is still two dispatches.
(I did look deeply into miniLM, and I vaguely remembered something like this, but had forgotten the details.)
If there is something better that can be done here, I am all for it (but I don't see it). The main constraint, though, is that we need to keep the indexing maps as permutations.
I am afraid that the expand shape is `<128x384xf32> -> <128x12x32xf32>`. Does this change anything in your comment? The expand shape operation itself does nothing to the data if it is laid out contiguously, which I believe is the case here.
What is the cost of having %11 as a separate node, as shown in the graph? Shouldn't it be a no-op?
> I am afraid that the expand shape is `<128x384xf32> -> <128x12x32xf32>`. Does this change anything in your comment? The expand shape operation itself does nothing to the data if it is laid out contiguously, which I believe is the case here.
Sorry, my bad. Fixed the comment, but the core of it holds. The reshape itself just changes metadata, but when you try to move the expand past the generic op, you have to collapse the generic op to 2D, and doing that introduces `mod`s and `div`s in the indexing maps, which is not allowed/not good for Linalg ops.
> What is the cost of having %11 as a separate node, as shown in the graph? Shouldn't it be a no-op?
The node in the graph itself is a no-op.
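The "no-op" claim can be checked with numpy's analogue of an expand shape on a contiguous buffer (numpy stands in for the runtime here; the point is only that the reshape is metadata, not data movement):

```python
import numpy as np

x = np.zeros((128, 384), dtype=np.float32)
y = x.reshape(128, 12, 32)  # analogue of tensor.expand_shape [[0], [1, 2]]

# For a contiguous layout the reshape is pure metadata: numpy returns a
# view over the same buffer rather than copying anything.
assert y.base is x
x[0, 0] = 7.0
assert y[0, 0, 0] == 7.0  # writes through the original are visible
```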
Thanks. It seems like a low priority then, unless we have a compelling reason to fuse it into a dispatch region; it may give more opportunity for advanced fusions, such as reshaping a whole dispatch region of purely element-wise ops.
Changing the priority to 2.
closing stale issue
I found a new case where we might want to fuse them into a single dispatch. The result of the matmul is only used by `flow.tensor.reshape`, and the result of `flow.tensor.reshape` is only used by the element-wise op. I think we can reorder it to `matmul -> elementwise -> reshape`; then we can fuse `matmul + elementwise` into a single dispatch.
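The reorder is valid because a purely element-wise op commutes with a reshape: applying the op before or after the shape change touches the same elements. A minimal numpy sketch of that commutation (the element-wise function here is an arbitrary stand-in, not the op from the model):

```python
import numpy as np

x = np.random.rand(128, 384).astype(np.float32)

def ew(t):
    # Any purely element-wise function; relu-then-scale is just an example.
    return np.maximum(t, 0.0) * 2.0

# matmul -> reshape -> elementwise  ==  matmul -> elementwise -> reshape
a = ew(x.reshape(128, 12, 32))   # original order
b = ew(x).reshape(128, 12, 32)   # reordered: elementwise first, reshape last
assert np.array_equal(a, b)
```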