nod-ai / SHARK-ModelDev

Unified compiler/runtime for interfacing with PyTorch Dynamo.

[fusion] Fold extsi operations into `linalg.conv_2d` operations. #755

Closed · MaheshRavishankar closed this 5 months ago

MaheshRavishankar commented 5 months ago

Using the instructions at https://github.com/nod-ai/playbook/blob/main/HOWTO/punet.md, the following sequence of IR is emitted before any fusion kicks in (i.e., before the first elementwise op fusion pass):

```mlir
%58 = tensor.empty() : tensor<2x320x130x130xi32>
%59 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%inserted_slice_90 : tensor<2x320x130x130xi8>) outs(%58 : tensor<2x320x130x130xi32>) {
^bb0(%in: i8, %out: i32):
  %6601 = arith.extsi %in : i8 to i32
  linalg.yield %6601 : i32
} -> tensor<2x320x130x130xi32>
%60 = tensor.empty() : tensor<320x320x3x3xi32>
%61 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%__auto.down_blocks.0.resnets.0.conv1.weight3Aqs : tensor<320x320x3x3xi8>) outs(%60 : tensor<320x320x3x3xi32>) {
^bb0(%in: i8, %out: i32):
  %6601 = arith.extsi %in : i8 to i32
  linalg.yield %6601 : i32
} -> tensor<320x320x3x3xi32>
%62 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%59, %61 : tensor<2x320x130x130xi32>, tensor<320x320x3x3xi32>)
    outs(%37 : tensor<2x320x128x128xi32>) -> tensor<2x320x128x128xi32>
```

Folding the extsi into the conv would make things a lot easier. @qedawkins, can you point to where these folds already exist for matmul ops? I couldn't find them.
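
For reference, the result I'd want the fold to produce is roughly the following sketch; it relies on the named conv implicitly sign-extending its i8 operands to the i32 accumulator type (its default cast behavior):

```mlir
// Desired folded form: the extsi generics disappear and the conv consumes
// the i8 tensors directly, accumulating in i32.
%62 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%inserted_slice_90, %__auto.down_blocks.0.resnets.0.conv1.weight3Aqs : tensor<2x320x130x130xi8>, tensor<320x320x3x3xi8>)
    outs(%37 : tensor<2x320x128x128xi32>) -> tensor<2x320x128x128xi32>
```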

qedawkins commented 5 months ago

It is happening here: https://github.com/iree-org/iree/blob/695e1932dd6cf91f2de5fc1415f10fe85fd269f0/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp#L687

Looks like it's already written for convolution and contraction ops, but it's restricted to floating point. We just need to add a case for signed integer extensions as well.
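
For illustration, the floating-point shape of IR that the existing pattern recognizes looks roughly like this (a sketch only, not the exact match conditions; `%input` and `%init` are placeholder values):

```mlir
// An f16 -> f32 extension written as a linalg.generic over arith.extf;
// the pattern folds this producer away so the consuming conv/contraction
// reads the narrower f16 tensor directly.
%ext = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%input : tensor<2x320x130x130xf16>) outs(%init : tensor<2x320x130x130xf32>) {
^bb0(%in: f16, %out: f32):
  %0 = arith.extf %in : f16 to f32
  linalg.yield %0 : f32
} -> tensor<2x320x130x130xf32>
```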

qedawkins commented 5 months ago

I'm also realizing that this pattern only works for floats: matmul_unsigned and matmul have different extension semantics, and doing this based on ContractionOpInterface doesn't tell us the extension semantics of the underlying named op.
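
To make that concrete, here is a minimal sketch with i8 inputs and an i32 accumulator (`%a`, `%b`, `%acc` are placeholders); as I understand it, `linalg.matmul` sign-extends its mixed-type operands while `linalg.matmul_unsigned` zero-extends them:

```mlir
// Implicitly sign-extends %a and %b to i32 (cast_signed in opdsl), so an
// arith.extsi producer could legally be folded into it.
%s = linalg.matmul ins(%a, %b : tensor<4x8xi8>, tensor<8x16xi8>)
                   outs(%acc : tensor<4x16xi32>) -> tensor<4x16xi32>

// Implicitly zero-extends (cast_unsigned), so only an arith.extui producer
// would be legal to fold here.
%u = linalg.matmul_unsigned ins(%a, %b : tensor<4x8xi8>, tensor<8x16xi8>)
                            outs(%acc : tensor<4x16xi32>) -> tensor<4x16xi32>
```

Both ops implement ContractionOpInterface, which is why matching on the interface alone can't tell which extension is baked into the body.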

MaheshRavishankar commented 5 months ago

So... do we close this issue, since we can't actually get named ops to do this? I don't know off the top of my head how linalg matmul ops handle extension semantics. Is it possible to have opdsl generate the sign extensions by default?

qedawkins commented 5 months ago

Sorry, I was unclear; there is no need to close this issue. I was just saying that the way the pattern in RaiseSpecialOps is written won't work for integer extensions, because the extension semantics are op-specific. We will have to handle the integer cases op by op, but the skeleton of the pattern I linked will still work.