HahaLan97 opened 2 months ago
I found out that the MLIR program generated from torch-mlir already contains this generic op, whose input and output have different shapes:
```mlir
%58 = linalg.generic {indexing_maps = [#map5, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%57#1 : tensor<1x1xf32>) outs(%57#0 : tensor<1x512x1x1xf32>) {
^bb0(%in: f32, %out: f32):
  %65 = arith.divf %out, %in : f32
  linalg.yield %65 : f32
} -> tensor<1x512x1x1xf32>
```
Because the sizes are both 1, the `SplitElementwiseGenericOp` pattern in the `simplify-copy` pass still applies, which then fails because a `memref::CopyOp` cannot be created between operands of different shapes. However, since you're using `applyPatternsAndFoldGreedily`, I added one check to the pattern: if the shapes are indeed different, the pattern returns failure and the pass goes on. With that, the correct affine loops are generated.
Furthermore, resnet18 got stuck in the `func-preprocess` pass, because of the patterns that convert the two integer arithmetic operations `arith.addi` and `arith.muli`. The cause is this code:
```cpp
if (auto lhs = add.getLhs().getDefiningOp<arith::ConstantIndexOp>(); isValidDim(add.getRhs()))
```
Here, only `isValidDim(add.getRhs())` is used as the condition; `lhs` is never checked for null. Changing it to the following fixes the error:
```cpp
auto lhs = add.getLhs().getDefiningOp<arith::ConstantIndexOp>();
if (lhs != nullptr && isValidDim(add.getRhs())) {
  // code
}
```
The other three CNNs don't hit the same problems as resnet18, because torch-mlir doesn't generate such a GenericOp for them and they contain no integer ops. I put all the changes in my forked repo and opened this PR; I hope it looks good to you.
P.S.
As the title says, I ran into this error when trying to run the scaleflow-pytorch-pipeline on the generated MLIR code. I added the debug-point option to the pipeline and traced the error to `SimplifyCopyPass`, in which a `memref.copy` is created using the `in` and `out` of `linalg.generic`. As you can see above, they do not have the same shape. So I wonder whether the issue behind this is in the GenericOp or in the pass, or maybe even a wrong version of torch-mlir? I hope the contributors of this repo could help ;D @hanchenye @signorgelato @jeonghm9764