llvm / torch-mlir

The Torch-MLIR project aims to provide first class support from the PyTorch ecosystem to the MLIR ecosystem.

Error while converting .mlir to .vmfb #2730

Open manishghop opened 10 months ago

manishghop commented 10 months ago

Model: https://huggingface.co/THUDM/chatglm2-6b

Issue: Error when applying padding to the input tokens.

During inference, the output token of each forward pass is appended to input_ids for the next forward pass, so the length of input_ids grows by 1 on every subsequent pass. To work around this we tried padding the inputs with several values of max_length (1, 5, 10, 15, 20), but the execution flow breaks while converting .mlir to .vmfb.
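A minimal stand-alone sketch of that padding attempt (the tokenizer setup and pad-token handling here are assumptions; the real code is in the gists below):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # assumption: reuse EOS as the pad token

prompt = "What is the capital of Canada?"
# Pad to a fixed max_length so every forward pass sees the same input_ids shape.
inputs = tokenizer(prompt, padding="max_length", max_length=20, return_tensors="pt")
print(inputs["input_ids"].shape)  # torch.Size([1, 20]) regardless of prompt length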

1) The code to compile the PyTorch model to .mlir & .vmfb: https://gist.github.com/manishghop/b03bfc1daa2d24e7e6b3d3b11699d739

2) This is what we are trying to replicate (without using SHARK): https://gist.github.com/manishghop/44b5bac1205e9d457a6744ebf118ba63. Using this as a reference, we are trying to replicate it with Nod AI SHARK.

3) This is what we aim to do using SHARK: https://gist.github.com/manishghop/529225d5e7e609b679f53fc4272be05c. Ideally we want the .vmfb model to handle prompts of different sizes instead of being bound to one fixed shape, which is currently not possible; one way to express a dynamic dimension is sketched below.
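One possible way to mark the sequence dimension as dynamic is a TensorPlaceholder with a dynamic axis; this is only a sketch assuming the TensorPlaceholder/dynamic_axes API of the torch_mlir.compile entry point seen in the tracebacks further down, with a toy model standing in for chatglm2-6b:

import torch
import torch_mlir

class TinyModel(torch.nn.Module):  # stand-in; the real model is the exported chatglm2-6b
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(65024, 32)

    def forward(self, input_ids):
        return self.embed(input_ids).sum(dim=1)

example_input_ids = torch.randint(0, 65024, (1, 9), dtype=torch.int64)  # (1, 9) prompt
# Mark dim 1 (the sequence length) as dynamic so the compiled module is not pinned to 9.
placeholder = torch_mlir.TensorPlaceholder.like(example_input_ids, dynamic_axes=[1])

module = torch_mlir.compile(TinyModel(), [placeholder], output_type="linalg-on-tensors")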

Reason: During inference, while generating new tokens we append the tokens predicted in each forward pass to input_ids for the subsequent forward passes. The initial code didn't account for dynamic changes to the shape of input_ids; it expects input_ids to keep the shape that was passed at torch_mlir compilation time. For example, for “What is the capital of Canada?” the shape is (1, 9). During inference it then expects input_ids to always have shape (1, 9), but from the second forward pass onward the shape grows by 1 until the stopping criterion is reached, hence the error.
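The growth itself is easy to see in a stand-alone sketch of the generation loop (the model call is faked with random logits):

import torch

input_ids = torch.randint(0, 65024, (1, 9))  # "What is the capital of Canada?" -> shape (1, 9)
for step in range(5):
    logits = torch.randn(1, input_ids.shape[1], 65024)           # stand-in for the model forward
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # greedy pick of the new token
    input_ids = torch.cat([input_ids, next_token], dim=-1)       # shape grows by 1 each pass
    print(input_ids.shape)  # (1, 10), (1, 11), ... no longer matches the compiled (1, 9)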

Using the Python API, the process simply hangs, but when using the command line: $ iree-compile ./chatglm2-6b-int4.mlir -o ./chatglm.vmfb we get the error shown in the attached screenshot.
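For reference, the Python-API path that hangs looks roughly like this (assuming the iree.compiler tools from the iree-compiler wheel; the target backend is illustrative, not taken from the gist):

from iree import compiler as ireec

# Rough equivalent of: iree-compile ./chatglm2-6b-int4.mlir -o ./chatglm.vmfb
ireec.compile_file(
    "./chatglm2-6b-int4.mlir",
    target_backends=["llvm-cpu"],
    output_file="./chatglm.vmfb",
)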

ramiro050 commented 10 months ago

This seems like an IREE issue. I think it would be better to raise an issue in the IREE repo, so that people with the right expertise can look at it.

AmosLewis commented 10 months ago

IREE issue: https://github.com/openxla/iree/issues/16068

AmosLewis commented 10 months ago

@manishghop I tried to reproduce the error but it fails earlier, at the TorchScript stage. Which versions of torch-mlir & iree-compiler are you using? My torch-mlir version is from today:

commit c7452af4fa7b4139dbd8b78b388b84a08b8c1b7a
Author: Chi_Liu <22491986+AmosLewis@users.noreply.github.com>
Date:   Fri Jan 12 14:54:38 2024 -0800
    [MLIR][ONNX] Add OnnxToTorch support for Maxpool Op (#2695)

python export_chatglm2.py

Traceback (most recent call last):
  File "/home/chi/src/test/chatglm/export_chatglm2.py", line 131, in <module>
    module = torch_mlir.compile(
             ^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 458, in compile
    run_pipeline_with_repro_report(
  File "/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 73, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:

python exception: Failure while executing pass pipeline:
error: "aten::_scaled_dot_product_flash_attention_for_cpu"("<eval_with_key>.5":127:50): unsupported by backend contract: Unimplemented operator 'aten._scaled_dot_product_flash_attention_for_cpu'
note: "aten::_scaled_dot_product_flash_attention_for_cpu"("<eval_with_key>.5":127:50): see current operation: %576:2 = "torch.operator"(%573, %574, %575, %342, %436, %345, %345) <{name = "aten._scaled_dot_product_flash_attention_for_cpu"}> : (!torch.tensor<[1,32,20,128],unk>, !torch.tensor<[1,32,20,128],unk>, !torch.tensor<[1,32,20,128],unk>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.tensor, !torch.tensor)

For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=quant.matmul_rhs_group_quant extra-library=/tmp/custom_op_extra_library.mlir})' /tmp/_lambda.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.

Here is the generated chatglm_elided.mlir

manishghop commented 10 months ago

@AmosLewis, the last time I checked I used torch_mlir-20240104. The iree_compiler should likewise be iree_compiler-20240104.

AmosLewis commented 10 months ago

@manishghop Found the cause of the torch-mlir error for "aten::_scaled_dot_product_flash_attention_for_cpu". The op was renamed from "aten::_scaled_dot_product_flash_attention" in PyTorch. The decomposition of aten::_scaled_dot_product_flash_attention_for_cpu was added to PyTorch two days ago in https://github.com/pytorch/pytorch/pull/117390 and https://github.com/pytorch/pytorch/pull/117097, and it has not yet shipped in a published PyTorch package, so we need to wait a few days for the release; the latest nightly is torch==2.3.0.dev20240110+cpu. In parallel, we can start adding torch.ops.aten._scaled_dot_product_flash_attention_for_cpu to the decomps_list in SHARK/shark/importer/import_with_fx, as sketched below.
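A rough sketch of what registering that decomposition looks like when tracing with make_fx (the actual decomps_list lives in SHARK's import_with_fx; this only shows the general mechanism and assumes a torch nightly that already carries the decomposition):

import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

# Add the renamed op to the list of ops decomposed during FX tracing, so the
# torch-mlir backend never sees _scaled_dot_product_flash_attention_for_cpu.
decomps_list = [
    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
]

def trace_with_decompositions(fn, example_args):
    return make_fx(fn, decomposition_table=get_decompositions(decomps_list))(*example_args)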

AmosLewis commented 10 months ago

I manually applied https://github.com/pytorch/pytorch/pull/117390 and https://github.com/pytorch/pytorch/pull/117097 in my local shark.venv and ran export_chatglm.py, but got this bug:

Traceback (most recent call last):
  File "/home/chi/src/SHARK/../test/chatglm/export_chatglm2.py", line 121, in <module>
    ts_graph = import_with_fx(
               ^^^^^^^^^^^^^^^
  File "/home/chi/src/SHARK/shark/shark_importer.py", line 698, in import_with_fx
    from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
  File "/home/chi/src/SHARK/shark.venv/lib/python3.11/site-packages/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py", line 58, in <module>
    from brevitas_examples.llm.llm_quant.mlir_custom_mm import brevitas_matmul_rhs_group_quant_library
  File "/home/chi/src/SHARK/shark.venv/lib/python3.11/site-packages/brevitas_examples/llm/llm_quant/mlir_custom_mm.py", line 12, in <module>
    from torch_mlir.dialects.torch.importer.jit_ir.build_tools.registry import \
ModuleNotFoundError: No module named 'torch_mlir.dialects.torch.importer'

Solution: In brevitas_examples/llm/llm_quant/mlir_custom_mm.py, replace the import on lines 12-15: change from torch_mlir.dialects.torch.importer.jit_ir.build_tools.registry to from torch_mlir.jit_ir_importer.build_tools.registry.
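The fix then looks roughly like this (the imported symbol is cut off in the traceback above, so Registry here is an assumption):

# Old import, removed in recent torch-mlir builds (raises ModuleNotFoundError):
# from torch_mlir.dialects.torch.importer.jit_ir.build_tools.registry import Registry

# New location of the same module:
from torch_mlir.jit_ir_importer.build_tools.registry import Registry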

OR install the latest version from GitHub with pip install git+https://github.com/Xilinx/brevitas.git to get brevitas==0.10.0.

AmosLewis commented 10 months ago

After the previous brevitas solution, I got another bug:

[DEBUG] Compiling torchscript graph
Traceback (most recent call last):
  File "/home/chi/src/SHARK/../test/chatglm/export_chatglm2.py", line 131, in <module>
    module = torch_mlir.compile(
             ^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 364, in compile
    extra_library_file_name = _canon_extra_library(extra_library)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 262, in _canon_extra_library
    mlir_library = generate_library(extra_library_dict)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/jit_ir_importer/build_tools/library_generator.py", line 228, in generate_library
    mb.import_function(function)
RuntimeError: required keyword attribute 'v' is undefined

We usually get this error when the torch version used to build torch-mlir differs from the torch version used to compile the model. Making the torch version the same in mlir_venv and shark.venv should fix it: torch==2.3.0.dev20240111.
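A quick way to confirm the two environments agree:

# Run inside both mlir_venv and shark.venv; the printed versions must match,
# e.g. 2.3.0.dev20240111+cpu in both.
import torch
print(torch.__version__)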