iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Fix Global Opt for 1x1 Filter Convs #19270

Open IanWood1 opened 8 hours ago

IanWood1 commented 8 hours ago

Using this issue to note down internal discussion about conv performance.

Problem

1x1 filter convolutions get converted to linalg.generic ops during the global optimization pipeline. For example:

%8 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%0, %2 : tensor<1x8x128x128xf16>, tensor<8x8x1x1xf16>) outs(%7 : tensor<1x8x128x128xf32>) -> tensor<1x8x128x128xf32>

Convert1X1FilterConv2DToMatmulPass (+ dropping unit dims) generalizes it to:

 %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%collapsed, %collapsed_0 : tensor<8x128x128xf16>, tensor<8x8xf16>) outs(%7 : tensor<8x128x128xf32>) {
    ^bb0(%in: f16, %in_2: f16, %out: f32):
      %16 = arith.extf %in : f16 to f32
      %17 = arith.extf %in_2 : f16 to f32
      %18 = arith.mulf %16, %17 : f32
      %19 = arith.addf %out, %18 : f32
      linalg.yield %19 : f32
    } -> tensor<8x128x128xf32>
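For intuition, a 1x1 convolution with unit stride and dilation is exactly a matmul between the filter (reduced to F x C) and the input collapsed to C x H*W, which is what the pass exploits. A minimal numpy sketch (shapes taken from the IR above; f32 throughout for simplicity, not the mixed f16/f32 of the actual op):

```python
import numpy as np

N, C, H, W, F = 1, 8, 128, 128, 8
x = np.random.rand(N, C, H, W).astype(np.float32)   # input, NCHW
w = np.random.rand(F, C, 1, 1).astype(np.float32)   # 1x1 filter, FCHW

# Direct 1x1 conv: out[n, f, h, w] = sum_c x[n, c, h, w] * w[f, c, 0, 0]
conv = np.einsum('nchw,fc->nfhw', x, w[:, :, 0, 0])

# Same result as a matmul after collapsing the unit batch and spatial dims:
#   (F x C) @ (C x H*W) -> F x H*W
mm = (w[:, :, 0, 0] @ x.reshape(C, H * W)).reshape(N, F, H, W)

assert np.allclose(conv, mm, atol=1e-4)
```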

This is problematic because the reduction dimension d3 indexes the outermost dimension of %collapsed (the first input to the generic). Ideally, we would transpose %collapsed with the permutation [1, 2, 0] to make the reduction the innermost dimension, and expect the transpose to fuse with its producer. This doesn't get handled during transpose propagation either.

For some more context, %collapsed is produced by transpose([1, 2, 0]) -> pad -> collapse_shape(remove unit dim). If the generic conv could fuse with that transpose, it would make the access contiguous along the reduction dimension.
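The stride argument can be sketched in numpy: in the current layout, consecutive steps along the reduction dimension of %collapsed are a full H*W plane apart, while after a [1, 2, 0] transpose they are unit-stride (the tensor shapes below mirror the IR; the stride arithmetic is illustrative, since linalg on tensors has no materialized strides until bufferization):

```python
import numpy as np

# Layout of %collapsed: the reduction dim (size 8) is outermost.
a = np.zeros((8, 128, 128), dtype=np.float16)
# Permutation [1, 2, 0] moves the reduction dim innermost.
t = np.ascontiguousarray(a.transpose(1, 2, 0))  # shape (128, 128, 8)

# Elements skipped per step along the reduction dimension:
before = a.strides[0] // a.itemsize   # 128 * 128 = 16384 (a whole plane)
after = t.strides[2] // t.itemsize    # 1 (contiguous)
print(before, after)
```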

Possible Solutions:

@MaheshRavishankar suggested either:


Related Issue: https://github.com/iree-org/iree/issues/19230

cc @qedawkins

IanWood1 commented 8 hours ago

I think we might want to reorder Convert1X1FilterConv2DToMatmulPass and ConvertConvToChannelsLastPass so that ConvertConvToChannelsLastPass runs first. They are adjacent in the pipeline, so the swap should have limited side effects, and ConvertConvToChannelsLastPass does what we want here: it transposes the inputs and converts to preferable conv op variants.

Switching them with no other changes results in:

%602 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%599, %__hoisted_tensor_8x16xf16 : tensor<128x128x8xf16>, tensor<8x16xf16>) outs(%601 : tensor<128x128x16xf32>) {
^bb0(%in: f16, %in_324: f16, %out: f32):
  %630 = arith.extf %in : f16 to f32
  %631 = arith.extf %in_324 : f16 to f32
  %632 = arith.mulf %630, %631 : f32
  %633 = arith.addf %out, %632 : f32
  linalg.yield %633 : f32
} -> tensor<128x128x16xf32>