nod-ai / SHARK-Studio

SHARK Studio -- Web UI for SHARK+IREE High Performance Machine Learning Distribution
Apache License 2.0

DistilGPT2 to TOSA #494

Closed AmosLewis closed 1 year ago

AmosLewis commented 2 years ago

The script to run: https://gist.github.com/AmosLewis/465795bfa1a4004fb96c742666515027

THE FINAL TOSA FILE: distilgpt2_tosa_20230123_elide.mlir

distilgpt2_tosa_20230123.mlir

Other bugs to fix:

RefineTypes1 crash [MERGED] https://github.com/llvm/torch-mlir/issues/1599
RefineTypes2 unk [MERGED]

aten.amax: Has a decomposition, but it does not work for distilgpt2. DecomposeAtenAmaxOp
torch.aten.select.int: Has a decomposition, but it does not work for distilgpt2. DecomposeAtenSelectIntOp

The two decomposition bugs were fixed by Ramiro: [MERGED] https://github.com/llvm/torch-mlir/pull/1750 + https://github.com/llvm/torch-mlir/pull/1769 + [MERGED] https://github.com/llvm/torch-mlir/pull/1742
select.int negative index support: https://github.com/llvm/torch-mlir/pull/1787

Ops to Fix:

aten.view [Merged] https://github.com/llvm/torch-mlir/pull/1768
torch.aten.slice.Tensor/select.int [Merged] https://github.com/llvm/torch-mlir/pull/1787
torch.prim.NumToTensor.Scalar f64 [TOSA NOT SUPPORT F64] https://github.com/llvm/torch-mlir/pull/1802

Ops to add:

torch.ops.prims.convert_element_type [MERGED] https://github.com/llvm/torch-mlir/pull/1619 Added a decomposition.
aten.as_strided: Fixed by deleting the decomposition of torch.ops.aten.slice.Tensor in the Python make_fx code. The as_strided op is generated by the slice decomposition.

Deprecated Chi patch: https://github.com/llvm/torch-mlir/pull/1742
Deprecated Alec patch: tosa WIP, linalg WIP

torch.aten.index.Tensor/torch.prim.ListConstruct [MERGED] https://github.com/llvm/torch-mlir/pull/1771. Still need to add a conversion for ListConstruct with non-constant values.

AmosLewis commented 2 years ago

Bug

/home/alec/Nod/torch-mlir/mlir_venv/lib/python3.10/site-packages/torch/jit/_check.py:181: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn("The TorchScript type system doesn't support "
python: /home/alec/Nod/torch-mlir/externals/llvm-project/llvm/include/llvm/Support/Casting.h:648: decltype(auto) llvm::dyn_cast(const From &) [To = mlir::IntegerType, From = mlir::Type]: Assertion `detail::isPresent(Val) && "dyn_cast on a non-existent value"' failed.
Aborted (core dumped)

JakopinA commented 2 years ago

Currently working on the op aten.as_strided

JakopinA commented 2 years ago

Still working on the op. Here's the draft PR: https://github.com/llvm/torch-mlir/pull/1656

JakopinA commented 1 year ago

Still working on the op; it's almost done. The op passes if you pass in a 1-D input. I just need to add support for inputs of higher rank.

AmosLewis commented 1 year ago

https://github.com/llvm/torch-mlir/compare/main...JakopinA:torch-mlir:jakopina-asstrided-tosa

AmosLewis commented 1 year ago

Bugs from torch_mlir upstream/main 20221220

➜  distillGPT2 git:(main) ✗ python distillgpt2.py
Some weights of the model checkpoint at distilgpt2 were not used when initializing GPT2ForSequenceClassification: ['lm_head.weight']
- This IS expected if you are initializing GPT2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
model(test_input): 
tensor([[0.8815, 0.0087]], grad_fn=<IndexBackward0>)
/home/chi/src/ubuntu20/shark/SHARK/shark.venv/lib/python3.10/site-packages/torch/jit/_check.py:181: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn("The TorchScript type system doesn't support "
python3.10: /home/chi/src/ubuntu20/shark/torch-mlir/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h:77: mlir::Type mlir::torch::Torch::BaseTensorType::getDtype() const: Assertion `hasDtype() && "must have a dtype"' failed.
[24]    1539869 abort (core dumped)  python3.10 distillgpt2.py

Bug1 from DecomposeAtenAmaxOp https://github.com/llvm/torch-mlir/pull/1636
Bug2 from DecomposeAtenSelectIntOp https://github.com/llvm/torch-mlir/pull/398
Detailed gdb trace: https://gist.github.com/AmosLewis/465795bfa1a4004fb96c742666515027

%171 = torch.aten.amax %169, %170, %true_0 : !torch.tensor, !torch.list<int>, !torch.bool -> !torch.tensor loc(#loc63) 
%840 = torch.aten.select.int %838, %int1, %int-1 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor loc(#loc419)

1221 Update1: The crash happens because the lowering for aten.as_strided is missing, and it happens before any op is lowered. If an op is not registered in Torch-MLIR, it shows up in the IR in the form torch.operator .... Because the shape and type info is also missing for the op, the shape and type inference passes cannot generate that info for its result, and every op that uses that result as an operand is missing it too. As a result, passes that run before the actual lowering, such as DecomposeComplexOps, crash if they rely on the shape and/or type info of that result.

If the op is added to the registry together with its shape and type info, the failure moves from the earlier passes to the TorchToTosa lowering. In this case the crash occurs when the code calls getDtype on an operand that does not have that info, which trips an assertion.
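To check whether an IR dump still contains unregistered ops, a small script can scan for torch.operator entries. This is only a sketch and not part of Torch-MLIR: the helper name find_unregistered_ops and the sample IR are made up for illustration, and the regex assumes an unregistered op prints as torch.operator "aten.<name>".

```python
import re

# Assumed textual form of an unregistered op in the IR:
#   %2 = torch.operator "aten.as_strided"(%1, ...) : (...) -> !torch.tensor
UNREGISTERED_OP = re.compile(r'torch\.operator\s+"([^"]+)"')


def find_unregistered_ops(mlir_text: str) -> list[str]:
    """Return the sorted, de-duplicated names of ops printed as torch.operator."""
    return sorted(set(UNREGISTERED_OP.findall(mlir_text)))


# Hypothetical two-line IR fragment, written by hand for this example.
sample_ir = "\n".join([
    "%1 = torch.aten.view %arg1, %0 : !torch.tensor, !torch.list<int> -> !torch.tensor",
    '%2 = torch.operator "aten.as_strided"(%1, %3, %4, %5) : '
    "(!torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.tensor",
])

print(find_unregistered_ops(sample_ir))
```

Any name this reports is a candidate for the missing registration plus shape/type info described above.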

1221 Update2: Tried adding as_strided with https://github.com/llvm/torch-mlir/pull/1742. Still got the same bug for DecomposeAtenAmaxOp. gdb details are in https://gist.github.com/AmosLewis/3d7872f07ddd3f9e6c07b6c311a900f2#file-distillgpt2_torch_delete_decompose_amax_selectint-mlir

1221 Update3: Update2 was also one of the reasons for the failure. Now that it is fixed, there is another reason for the failure: the missing refine-type info for the tanh op. @vivekkhandelwal1 has added that here: https://github.com/llvm/torch-mlir/pull/1745. Here's the torch IR generated: https://gist.github.com/3f912e3a9ba62ce2895533185f837b44

1222 Updates1: Ramiro will fix the decompose bug in https://github.com/llvm/torch-mlir/issues/1749 https://github.com/llvm/torch-mlir/pull/1750

AmosLewis commented 1 year ago

After the two bugs above, I got a convert-to-TOSA bug:
torch-mlir-opt -convert-torch-to-tosa /tmp/_lambda.mlir -mlir-print-ir-after-all -mlir-disable-threading --mlir-pretty-debuginfo --mlir-print-ir-after-failure --mlir-print-op-on-diagnostic --mlir-print-debuginfo

//===-------------------------------------------===//
Legalizing operation : 'torch.aten.view'(0x9533eb0) {
  %128 = "torch.aten.view"(%arg1, %127) : (!torch.tensor, !torch.list<int>) -> !torch.tensor

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.view -> ()' {
Trying to match "(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenViewOp>"
    ** Failure : Only tensor types are currently supported
"(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenViewOp>" result 0
  } -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern

From:

%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,128],si64>
%0 = torch.prim.ListConstruct %int-1, %int128 : (!torch.int, !torch.int) -> !torch.list<int> [unknown]
%1 = torch.aten.view %arg1, %0 : !torch.tensor, !torch.list<int> -> !torch.tensor <eval_with_key>.2:5:11
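One possible reading of the "Only tensor types are currently supported" failure is that the IR handed to -convert-torch-to-tosa still uses non-value !torch.tensor types instead of !torch.vtensor, so the pattern never sees a builtin tensor. That is an assumption, not something confirmed in this thread. A rough sketch for checking a file for this (the helper name non_value_tensor_lines is hypothetical):

```python
import re

# Matches "!torch.tensor" but not "!torch.vtensor".
NON_VALUE_TENSOR = re.compile(r"!torch\.tensor\b")


def non_value_tensor_lines(mlir_text: str) -> list[int]:
    """Return 1-based line numbers that still mention a non-value !torch.tensor."""
    return [
        number
        for number, line in enumerate(mlir_text.splitlines(), start=1)
        if NON_VALUE_TENSOR.search(line)
    ]


# The three IR lines quoted above, with the argument attribute brace closed.
snippet = "\n".join([
    "%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,128],si64>}",
    "%0 = torch.prim.ListConstruct %int-1, %int128 : (!torch.int, !torch.int) -> !torch.list<int>",
    "%1 = torch.aten.view %arg1, %0 : !torch.tensor, !torch.list<int> -> !torch.tensor",
])

print(non_value_tensor_lines(snippet))
```

If this reports any lines, the file has probably not yet reached the value-semantic form the TOSA conversion patterns expect.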

AmosLewis commented 1 year ago

With the as_strided and decompose bugs fixed, we then hit a torch.aten.slice.Tensor bug:

error: failed to legalize operation 'torch.aten.slice.Tensor' that was explicitly marked illegal
note: see current operation: %1652 = "torch.aten.slice.Tensor"(%1648, %1, %5, %3, %1) : (!torch.vtensor<[1,128,2],f32>, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,1,2],f32>
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.slice.Tensor'(0x9dcedf0) {
  %1652 = "torch.aten.slice.Tensor"(%1648, %1, %5, %3, %1) : (!torch.vtensor<[1,128,2],f32>, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,1,2],f32>

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.slice.Tensor -> ()' {
Trying to match "(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenSliceTensorOp>"
    ** Failure : Currently unsupported: start < 0
"(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenSliceTensorOp>" result 0
  } -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
<eval_with_key>.2:543:13: error: failed to legalize operation 'torch.aten.slice.Tensor' that was explicitly marked illegal
<eval_with_key>.2:543:13: note: see current operation: %1652 = "torch.aten.slice.Tensor"(%1648, %1, %5, %3, %1) : (!torch.vtensor<[1,128,2],f32>, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,1,2],f32>
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::PreservedAnalyses::AllAnalysesType)

The weird thing is: I used torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' /tmp/_lambda.mlir -mlir-print-ir-after-all -mlir-disable-threading --debug to get the final mlir file, which mixes tosa and torch ops: https://gist.github.com/AmosLewis/01e4eaaab7f57940251679b8c818e932 . torch.aten.slice.Tensor is among the last few ops in this file, but plenty of torch ops remain before it, which means their rewrite patterns failed. Yet none of them reported a pattern mismatch error. This doesn't make sense to me, because the lowering to tosa for those ops, say aten.view, exists. If their rewrites failed, shouldn't the error be reported before torch.aten.slice.Tensor? I am planning to fix torch.aten.slice.Tensor first, after the as_strided to tosa work is done.
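For the "start < 0" failure, the usual way to handle a negative start is to normalize it against the dimension size before building a static slice. The sketch below shows that arithmetic in plain Python. It is not the code from the torch-mlir PR. The operand values are not shown in the log, so dim = 1, start = -1, step = 1 and a large end sentinel are assumptions chosen to match the [1,128,2] -> [1,1,2] shapes in the error.

```python
END_SENTINEL = 2**63 - 1  # assumed "slice to the end" value, not taken from the log


def normalize_index(index: int, dim_size: int) -> int:
    """Map a possibly negative index into the range [0, dim_size]."""
    if index < 0:
        index += dim_size
    return max(0, min(index, dim_size))


def slice_start_and_size(shape, dim, start, end, step=1):
    """Return per-dimension (start, size) lists for a step-1 slice along `dim`."""
    if step != 1:
        raise ValueError("this sketch only handles step == 1")
    begin = normalize_index(start, shape[dim])
    stop = normalize_index(end, shape[dim])
    starts = [0] * len(shape)
    sizes = list(shape)
    starts[dim] = begin
    sizes[dim] = max(0, stop - begin)
    return starts, sizes


starts, sizes = slice_start_and_size([1, 128, 2], dim=1, start=-1, end=END_SENTINEL)
print(starts, sizes)  # [0, 127, 0] [1, 1, 2]
```

Under these assumptions the result lines up with the tosa.slice that appears later in this thread, which uses start = [0, 127, 0] and size = [1, 1, 2].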

AmosLewis commented 1 year ago

Based on distilgpt2_torchtotosa_debug.mlir, these are the 34 ops that were not converted successfully in distilgpt2:

torch.aten.as_strided
torch.aten.view
torch.aten.arange.start_step 
torch.aten.unsqueeze
torch.aten.embedding
torch.aten.add.Tensor
torch.aten.native_layer_norm
torch.aten.mm
torch.aten.mul.Scalar
torch.aten.add.Tensor
torch.aten.permute 
torch.aten.transpose.int
torch.aten.broadcast_to
torch.aten.bmm
torch.aten.clone
torch.aten.div.Tensor
torch.aten.to.dtype
torch.aten.clone
torch.aten.where.self
torch.aten.max.dim
torch.aten.sub.Tensor
torch.aten.exp
torch.aten.sum.dim_IntList
torch.aten.div.Tensor
torch.aten.broadcast_to
torch.aten.bmm
torch.aten.pow.Tensor_Scalar
torch.aten.mul.Scalar
torch.aten.sum.dim_IntList
torch.aten.tanh
torch.aten.transpose.int 

torch.aten.slice.Tensor
torch.aten.squeeze.dim
torch.aten.index.Tensor
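The list above contains repeats, since the same op shows up at several points in the model. A small sketch for reducing it to the distinct ops that need work (plain Python, with the list copied from above; the variable names are only for illustration):

```python
from collections import Counter

# The 34 entries from the list above, in order.
UNCONVERTED = """
torch.aten.as_strided torch.aten.view torch.aten.arange.start_step
torch.aten.unsqueeze torch.aten.embedding torch.aten.add.Tensor
torch.aten.native_layer_norm torch.aten.mm torch.aten.mul.Scalar
torch.aten.add.Tensor torch.aten.permute torch.aten.transpose.int
torch.aten.broadcast_to torch.aten.bmm torch.aten.clone
torch.aten.div.Tensor torch.aten.to.dtype torch.aten.clone
torch.aten.where.self torch.aten.max.dim torch.aten.sub.Tensor
torch.aten.exp torch.aten.sum.dim_IntList torch.aten.div.Tensor
torch.aten.broadcast_to torch.aten.bmm torch.aten.pow.Tensor_Scalar
torch.aten.mul.Scalar torch.aten.sum.dim_IntList torch.aten.tanh
torch.aten.transpose.int torch.aten.slice.Tensor torch.aten.squeeze.dim
torch.aten.index.Tensor
""".split()

counts = Counter(UNCONVERTED)
distinct = sorted(counts)
repeated = sorted(op for op, n in counts.items() if n > 1)

print(len(UNCONVERTED), "entries,", len(distinct), "distinct")  # 34 entries, 26 distinct
print("repeated:", repeated)
```

So the 34 entries come down to 26 distinct ops, 8 of which appear twice.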

AmosLewis commented 1 year ago

After as_strided and slice.Tensor were fixed, only torch.aten.index.Tensor is left. mlir file: https://gist.github.com/AmosLewis/f2d8fb5d121d38a18577aea5234b8d45

  %530 = torch_c.from_builtin_tensor %24 : tensor<1xi64> -> !torch.vtensor<[1],si64>
  %531 = "tosa.slice"(%529) {size = [1, 1, 2], start = [0, 127, 0]} : (tensor<1x128x2xf32>) -> tensor<1x1x2xf32>
  %532 = "tosa.reshape"(%531) {new_shape = [1, 2]} : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
  %533 = torch_c.from_builtin_tensor %532 : tensor<1x2xf32> -> !torch.vtensor<[1,2],f32>
  %534 = torch.prim.ListConstruct %530 : (!torch.vtensor<[1],si64>) -> !torch.list<vtensor>
  %535 = torch.aten.index.Tensor %533, %534 : !torch.vtensor<[1,2],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,2],f32>
  %536 = torch_c.to_builtin_tensor %535 : !torch.vtensor<[1,2],f32> -> tensor<1x2xf32>
//===-------------------------------------------===//
Legalizing operation : 'torch.prim.ListConstruct'(0x93a7830) {
  %534 = "torch.prim.ListConstruct"(%530) : (!torch.vtensor<[1],si64>) -> !torch.list<vtensor>

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
<eval_with_key>.2:544:12: error: failed to legalize operation 'torch.prim.ListConstruct'
<eval_with_key>.2:544:12: note: see current operation: %534 = "torch.prim.ListConstruct"(%530) : (!torch.vtensor<[1],si64>) -> !torch.list<vtensor>
// -----// IR Dump After FinalizingBackendTypeConversion Failed (torch-finalizing-backend-type-conversion) //----- //
mlir-asm-printer: Verifying operation: func.func

AmosLewis commented 1 year ago

With transformers 4.25.1, there are new ops to fix. First, torch.prim.NumToTensor.Scalar with f64:

Traceback (most recent call last):
  File "/home/chi/src/ubuntu20/shark/torch-mlir/distillGPT2/distillgpt2.py", line 92, in <module>
    module = torch_mlir.compile(
  File "/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 386, in compile
    run_pipeline_with_repro_report(
  File "/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 73, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Torch Backend IR -> TOSA Backend IR failed with the following diagnostics:
error: failed to legalize operation 'torch.prim.NumToTensor.Scalar' that was explicitly marked illegal
note: see current operation: %318 = "torch.prim.NumToTensor.Scalar"(%188) : (!torch.float) -> !torch.vtensor<[],f64>

For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' /tmp/_lambda.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.

Then torch.aten.to.dtype f64

Traceback (most recent call last):
  File "/home/chi/src/ubuntu20/shark/torch-mlir/distillGPT2/distillgpt2.py", line 92, in <module>
    module = torch_mlir.compile(
  File "/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 386, in compile
    run_pipeline_with_repro_report(
  File "/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 73, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Torch Backend IR -> TOSA Backend IR failed with the following diagnostics:
error: failed to legalize operation 'torch.aten.to.dtype' that was explicitly marked illegal
note: see current operation: %321 = "torch.aten.to.dtype"(%320, %8, %9, %9, %11) : (!torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none) -> !torch.vtensor<[],f32>

AmosLewis commented 1 year ago

f64 is a known issue for tosa, and there is no easy fix for it right now: https://github.com/llvm/torch-mlir/issues/1615. So I am planning to downgrade to transformers 4.21.2.
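To see up front which ops in a diagnostics dump involve f64 tensors, the "see current operation" notes can be filtered with a couple of regexes. This is a minimal sketch; the helper name ops_touching_f64 is made up, and it only looks at the generic quoted op form that appears in the error notes above.

```python
import re

F64_TENSOR = re.compile(r"!torch\.vtensor<\[[^\]]*\],f64>")
OP_NAME = re.compile(r'"(torch\.[\w.]+)"')


def ops_touching_f64(diagnostics: str) -> list[str]:
    """Return op names, in first-seen order, from lines that mention an f64 vtensor."""
    found = []
    for line in diagnostics.splitlines():
        if not F64_TENSOR.search(line):
            continue
        match = OP_NAME.search(line)
        if match and match.group(1) not in found:
            found.append(match.group(1))
    return found


# The two notes from the tracebacks above.
notes = "\n".join([
    'note: see current operation: %318 = "torch.prim.NumToTensor.Scalar"(%188) : '
    "(!torch.float) -> !torch.vtensor<[],f64>",
    'note: see current operation: %321 = "torch.aten.to.dtype"(%320, %8, %9, %9, %11) : '
    "(!torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none) -> "
    "!torch.vtensor<[],f32>",
])

print(ops_touching_f64(notes))
```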

AmosLewis commented 1 year ago

func.func @torch.aten.slice(%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,128],si64>}) -> !torch.tensor {
  %int-1 = torch.constant.int -1
  %int128 = torch.constant.int 128
  %0 = torch.prim.ListConstruct %int-1, %int128 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.view %arg1, %0 : !torch.tensor, !torch.list<int> -> !torch.tensor
  return %1 : !torch.tensor
}
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.view'(0x563996c2d170) {
  %3 = "torch.aten.view"(%arg0, %2) : (!torch.tensor, !torch.list<int>) -> !torch.tensor

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.view -> ()' {
Trying to match "(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenViewOp>"
    ** Failure : Only tensor types are currently supported
"(anonymous namespace)::ConvertAtenOp<mlir::torch::Torch::AtenViewOp>" result 0
  } -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
/tmp/view.mlir:5:8: error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
  %1 = torch.aten.view %arg1, %0 : !torch.tensor, !torch.list<int> -> !torch.tensor
       ^
/tmp/view.mlir:5:8: note: see current operation: %3 = "torch.aten.view"(%arg0, %2) : (!torch.tensor, !torch.list<int>) -> !torch.tensor
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::PreservedAnalyses::AllAnalysesType)
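The --debug legalization dumps in this thread all follow the same shape, so the failing ops and their reasons can be pulled out with a short script instead of by eye. This is a sketch based only on the log lines quoted in this thread; the helper name summarize_failures is hypothetical and other MLIR versions may print a different layout.

```python
import re

LEGALIZING = re.compile(r"^Legalizing operation : '([^']+)'")
FAILURE = re.compile(r"^\s*\*\* Failure : (.+)$")
NO_MATCH = "} -> FAILURE : no matched legalization pattern"


def summarize_failures(log_text: str) -> list[tuple]:
    """Return (op name, failure reason or None) for each op that failed to legalize."""
    failures = []
    op, reason = None, None
    for line in log_text.splitlines():
        match = LEGALIZING.match(line)
        if match:
            op, reason = match.group(1), None
            continue
        match = FAILURE.match(line)
        if match:
            reason = match.group(1).strip()
            continue
        if op is not None and line.startswith(NO_MATCH):
            failures.append((op, reason))
            op = None
    return failures


# Trimmed copies of two dumps from this thread.
sample_log = "\n".join([
    "Legalizing operation : 'torch.aten.view'(0x563996c2d170) {",
    "  * Pattern : 'torch.aten.view -> ()' {",
    "    ** Failure : Only tensor types are currently supported",
    "  } -> FAILURE : pattern failed to match",
    "} -> FAILURE : no matched legalization pattern",
    "Legalizing operation : 'torch.prim.ListConstruct'(0x93a7830) {",
    "  * Fold {",
    "  } -> FAILURE : unable to fold",
    "} -> FAILURE : no matched legalization pattern",
])

for name, why in summarize_failures(sample_log):
    print(name, "->", why)
```

An op reported with no reason, like torch.prim.ListConstruct here, had no pattern attempt a match at all, which is a different situation from a pattern that ran and rejected the op.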

AmosLewis commented 1 year ago

https://gist.github.com/AmosLewis/f69629625defdcd9105c3216226e5e9f