Open IanWood1 opened 8 hours ago
I think we might want to reorder Convert1X1FilterConv2DToMatmulPass and ConvertConvToChannelsLastPass so that ConvertConvToChannelsLastPass runs first. They are adjacent, so it should have limited side effects, and I think it does what we want: it transposes inputs and converts to preferable conv op variants.
Switching them with no other changes results in:
%602 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%599, %__hoisted_tensor_8x16xf16 : tensor<128x128x8xf16>, tensor<8x16xf16>) outs(%601 : tensor<128x128x16xf32>) {
^bb0(%in: f16, %in_324: f16, %out: f32):
%630 = arith.extf %in : f16 to f32
%631 = arith.extf %in_324 : f16 to f32
%632 = arith.mulf %630, %631 : f32
%633 = arith.addf %out, %632 : f32
linalg.yield %633 : f32
} -> tensor<128x128x16xf32>
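To make the layout point concrete, here is a minimal pure-Python sketch of what the indexing maps in the generic above mean. It is only an illustration: the helper names (`row_major_strides`, `permute`, `contract`) are made up for this example, the 8x128x128 "channels-first" shape is an assumed counterpart used for comparison, and the permutation follows the NumPy-style convention (result dim i takes source dim perm[i]).

```python
def row_major_strides(shape):
    """Element strides of a dense row-major tensor with the given shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def permute(shape, perm):
    """NumPy-style transpose of a shape: result dim i takes source dim perm[i]."""
    return [shape[p] for p in perm]


def contract(lhs, rhs):
    """Reference semantics of the maps (d0, d1, d3), (d3, d2) -> (d0, d1, d2)."""
    n0, n1, n3, n2 = len(lhs), len(lhs[0]), len(rhs), len(rhs[0])
    out = [[[0.0] * n2 for _ in range(n1)] for _ in range(n0)]
    for d0 in range(n0):
        for d1 in range(n1):
            for d2 in range(n2):
                for d3 in range(n3):  # reduction walks the last dim of lhs
                    out[d0][d1][d2] += lhs[d0][d1][d3] * rhs[d3][d2]
    return out


# First input of the generic above: 128x128x8, with reduction d3 as its last dim.
channels_last = [128, 128, 8]
# Assumed layout for comparison only: reduction dim outermost.
channels_first = [8, 128, 128]

print(row_major_strides(channels_last))    # [1024, 8, 1]: d3 has stride 1
print(row_major_strides(channels_first))   # [16384, 128, 1]: d3 has stride 16384
print(permute(channels_first, [1, 2, 0]))  # [128, 128, 8]
```

With the reduction dimension last, consecutive reduction steps read adjacent elements. In the assumed channels-first layout, each step would jump 16384 elements.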
Using this issue to note down internal discussion about conv performance.
Problem
1x1 filter convolutions get converted to linalg.generic ops during the global optimization pipeline. For example:
Convert1X1FilterConv2DToMatmulPass (+ dropping unit dims) generalizes it to:
This is problematic because the reduction dimension d3 is the outermost dimension of %collapsed (the first input to the generic). Ideally, we would transpose %collapsed with the permutation [1, 2, 0] to make the reduction dimension the innermost and expect the transpose to fuse with its producer. This doesn't get handled during transpose propagation either.
For some more context, %collapsed is produced by transpose([1, 2, 0]) -> pad -> collapse_shape (remove unit dim). If the generic conv could fuse with that transpose, it would make the access contiguous along the reduction dimension.
Possible Solutions:
@MaheshRavishankar suggested either:
Global Optimization to be: transpose propagation -> pad to intrinsic -> generalization
Related Issue: https://github.com/iree-org/iree/issues/19230
cc @qedawkins