llvm / torch-mlir

The Torch-MLIR project aims to provide first-class support from the PyTorch ecosystem to the MLIR ecosystem.

Bug in lowering of AtenView -> Tensor::Expand_Shape #2008

Open Abhishek-Varma opened 1 year ago

Abhishek-Varma commented 1 year ago

Input MLIR :-

func.func @view(%arg0: !torch.vtensor<[64,64],f16>) -> !torch.vtensor<[1,4096,1],f16>
{
  %int1 = torch.constant.int 1
  %int-1 = torch.constant.int -1
  %shape = torch.prim.ListConstruct %int1, %int-1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %output = torch.aten.view %arg0, %shape : !torch.vtensor<[64,64],f16>, !torch.list<int> -> !torch.vtensor<[1,4096,1],f16>
  return %output : !torch.vtensor<[1,4096,1],f16>
}
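For reference, the shape inference that torch.aten.view performs here can be sketched in plain Python: a single -1 in the target shape is resolved so that the total element count is preserved, which is how [1, -1, 1] becomes [1, 4096, 1] for a 64x64 input. (`infer_view_shape` is an illustrative name, not a torch-mlir API.)

```python
def infer_view_shape(input_shape, target_shape):
    """Resolve at most one -1 in target_shape so element counts match."""
    numel = 1
    for d in input_shape:
        numel *= d
    known = 1
    neg_index = None
    for i, d in enumerate(target_shape):
        if d == -1:
            assert neg_index is None, "at most one -1 allowed"
            neg_index = i
        else:
            known *= d
    resolved = list(target_shape)
    if neg_index is not None:
        assert numel % known == 0, "shapes are incompatible"
        resolved[neg_index] = numel // known
    return resolved

# The view in the report: [64, 64] -> [1, -1, 1] resolves to [1, 4096, 1].
print(infer_view_shape([64, 64], [1, -1, 1]))  # -> [1, 4096, 1]
```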

On passing it through --convert-torch-to-linalg, the following error is thrown:

error: 'tensor.expand_shape' op expected reassociation map #0 of same rank as expanded memref(3), but got 2
note: see current operation: %expanded = "tensor.expand_shape"(%collapsed) {reassociation = [[0, 1]]} : (tensor<4096xf16>) -> tensor<1x4096x1xf16>

With a fix I have, the correct IR gets generated: %expanded = tensor.expand_shape %collapsed [[0, 1, 2]] : tensor<4096xf16> into tensor<1x4096x1xf16> (observe the reassociation map here).
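The corrected reassociation can be derived mechanically: each result dimension is grouped under the source dimension it expands from, and unit dimensions attach to an adjacent group so that every result index appears exactly once. A hedged Python sketch of that grouping, assuming non-unit source dimensions (`expand_reassociation` is an illustrative name, not the torch-mlir helper):

```python
def expand_reassociation(src_shape, dst_shape):
    """Group each dst dim index under the src dim it expands from.

    Greedy sketch: accumulate dst dims into the current group until
    their product matches the current src dim; trailing unit dims
    fold into the last group. Assumes src_shape has no unit dims.
    """
    groups = [[] for _ in src_shape]
    si = 0      # index of the src dim being matched
    prod = 1    # product of dst dims gathered for that src dim
    for di, d in enumerate(dst_shape):
        if si == len(src_shape):
            # all src dims matched: remaining dims must be unit dims
            groups[-1].append(di)
            continue
        groups[si].append(di)
        prod *= d
        if prod == src_shape[si]:
            si += 1
            prod = 1
    return groups

# The failing case: tensor<4096xf16> -> tensor<1x4096x1xf16>
print(expand_reassociation([4096], [1, 4096, 1]))  # -> [[0, 1, 2]]
```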

But it turns out that the current AtenViewOp lowering seems to be reinventing the wheel for deciphering ReassociationIndices - I'll have to go through it a bit to make a generic patch.

This is needed for ToMe support - I believe I should work on raising a patch for the fix anyway.

After addressing the above, I saw two instances where negative dimensions were causing issues in the LLVM pipeline:

  1. DimOfReifyRankedShapedTypeOpInterface.
  2. FoldDimOfExpandShape.

I've patched them up as well, but two questions here:

  1. Should LLVM passes take care of Python's negative dimension cases? I believe they shouldn't, even though I've temporarily patched things up at externals/llvm-project. Let me know your thoughts here; I'd accordingly raise a patch for llvm-project separately.
  2. Ideally, all negative dimension indexing should be normalized. Is there any pass currently in torch-mlir that's supposed to take care of this? If there isn't, I believe it's worth investing effort on this front.

@powderluv @ramiro050

ramiro050 commented 1 year ago

But turns out that the current AtenViewOp's implementation seems to be reinventing the wheel for deciphering ReassociationIndices - I'll have to go through it a bit to make a generic patch.

Yeah, the AtenViewOp conversion pattern has become quite complex. Any simplifications to it are definitely welcomed!

Ideally all the negative dimension indexing should be normalized. Is there any such pass currently in torch-mlir that's supposed to take care of this? If there isn't, I believe it's worth investing an effort on this front

Yes, all negative dimensions should be properly converted to positive dimensions in torch-mlir when converting to one of the three backends. All patterns converting to one of the three backends should be doing such transformations using the helpers toPositiveDim and isValidDim, but it's possible that some ops do not have such handling.
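For readers unfamiliar with those helpers, their semantics can be sketched in Python (the real helpers are C++ utilities inside torch-mlir; this is only an illustration of the normalization they perform):

```python
# Python-style negative dims count from the end: dim -1 of a rank-3
# tensor is dim 2. Normalization adds the rank to negative indices.
def to_positive_dim(dim, rank):
    """Sketch of toPositiveDim: map a possibly-negative dim into [0, rank)."""
    return dim + rank if dim < 0 else dim

def is_valid_dim(dim, rank):
    """Sketch of isValidDim: check a (normalized) dim is in range."""
    return 0 <= dim < rank

dim = to_positive_dim(-1, 3)  # -> 2
assert is_valid_dim(dim, 3)
```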

Abhishek-Varma commented 1 year ago

After addressing the above, I saw two instances where negative dimensions were being an issue in the LLVM pipeline :-

  1. DimOfReifyRankedShapedTypeOpInterface.
  2. FoldDimOfExpandShape.

Raised a PR, https://github.com/llvm/torch-mlir/pull/2013, to deal with negative dimensions. This solves the aforementioned part - no need to patch things up at llvm-project.

createLinalgPayloadCalculationForGatherOps needs special mention here, because this function in particular led to all the red-herring issues at llvm-project.

Working on addressing AtenViewOp next.