wangzhaode / llm-export

llm-export can export LLM models to ONNX.
Apache License 2.0

Concat error when exporting the Llama-2-7b-chat-ms model #29

Open mi-tao opened 6 months ago

mi-tao commented 6 months ago

Error message

Exception raised: AssertionError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
exception: no description
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py", line 539, in cat
    assert all(
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 306, in wrapper
    return fn(g, *args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/symbolic_opset11.py", line 551, in cat
    return opset9.cat(g, tensor_list, dim)
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 392, in wrapper
    return fn(g, *args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1891, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/utils.py", line 665, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1117, in _model_to_graph
    graph = _optimize_graph(
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/usr/local/lib/python3.10/site-packages/torch/onnx/utils.py", line 506, in export
    _export(
  File "/work/hu/alg_sources/llm-export/llm_export.py", line 228, in export_block
    torch.onnx.export(
  File "/work/hu/alg_sources/llm-export/llm_export.py", line 251, in export_blocks
    self.export_block(i)
  File "/work/hu/alg_sources/llm-export/llm_export.py", line 868, in <module>
    llm_exporter.export_blocks()
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,
AssertionError:

Location of the failing code (torch/onnx/symbolic_opset9.py)

@_onnx_symbolic("aten::cat")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def cat(g: jit_utils.GraphContext, tensor_list, dim):
    tensors = symbolic_helper._unpack_list(tensor_list)
    # torch.cat ignores empty tensors such as `torch.Tensor([])`
    # These need to be removed as input from ONNX's concat too, otherwise shape inference
    # will likely fail due to inputs with different ranks (0 for empty tensor, > 0 for anything else)
    nonempty_tensors = []
    for t in tensors:
        if symbolic_helper._is_constant(t) and not symbolic_helper._get_tensor_dim_size(
            t, 0
        ):
            continue
        nonempty_tensors.append(t)
    assert len(nonempty_tensors) > 0
    assert all(
        [
            symbolic_helper._get_tensor_rank(nonempty_tensors[0]) is None
            or symbolic_helper._get_tensor_rank(t) is None
            or symbolic_helper._get_tensor_rank(t)
            == symbolic_helper._get_tensor_rank(nonempty_tensors[0])
            for t in nonempty_tensors
        ]
    )
    tensor_list.node().removeAllInputs()
    for t in nonempty_tensors:
        tensor_list.node().addInput(t)

    tensors = symbolic_helper._unpack_list(tensor_list)
    return g.op("Concat", *tensors, axis_i=dim)
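The assert that fires here is the rank-equality check in the middle of cat(). A minimal stand-alone sketch of that condition (not part of PyTorch, just a re-statement of the logic above, with None modeling an unknown rank as returned by _get_tensor_rank):

```python
# Sketch of the rank-equality condition cat() asserts on: every
# non-empty input to Concat must have the same rank, unless a rank
# is unknown (None). Ranks 4 and 6, as in this issue, fail the check.
def ranks_compatible(ranks):
    return all(
        ranks[0] is None or r is None or r == ranks[0]
        for r in ranks
    )

print(ranks_compatible([4, 4]))     # True
print(ranks_compatible([4, None]))  # True: unknown rank is tolerated
print(ranks_compatible([4, 6]))     # False -> AssertionError in cat()
```

So the exporter only rejects the graph when it can prove that two non-empty inputs have different ranks, which is exactly the situation reported below.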

Details

The two tensors in nonempty_tensors have shapes [1, 32, 0, 128] and [1, 1, 3, 32, 3, 128] respectively, so they cannot be concatenated (their ranks differ: 4 vs 6).

Llama-2-7b-chat-ms model download: https://modelscope.cn/models/modelscope/Llama-2-7b-ms/files

Python environment:

certifi                       2023.11.17
charset-normalizer            3.3.2
cmake                         3.28.1
coloredlogs                   15.0.1
filelock                      3.13.1
flatbuffers                   23.5.26
fsspec                        2023.12.2
huggingface-hub               0.20.1
humanfriendly                 10.0
idna                          3.6
Jinja2                        3.1.2
lit                           17.0.6
markdown-it-py                3.0.0
MarkupSafe                    2.1.3
mdurl                         0.1.2
mpmath                        1.3.0
networkx                      3.2.1
numpy                         1.25.2
nvidia-cublas-cu11            11.10.3.66
nvidia-cublas-cu12            12.1.3.1
nvidia-cuda-cupti-cu11        11.7.101
nvidia-cuda-cupti-cu12        12.1.105
nvidia-cuda-nvrtc-cu11        11.7.99
nvidia-cuda-nvrtc-cu12        12.1.105
nvidia-cuda-runtime-cu11      11.7.99
nvidia-cuda-runtime-cu12      12.1.105
nvidia-cudnn-cu11             8.5.0.96
nvidia-cudnn-cu12             8.9.2.26
nvidia-cufft-cu11             10.9.0.58
nvidia-cufft-cu12             11.0.2.54
nvidia-curand-cu11            10.2.10.91
nvidia-curand-cu12            10.3.2.106
nvidia-cusolver-cu11          11.4.0.1
nvidia-cusolver-cu12          11.4.5.107
nvidia-cusparse-cu11          11.7.4.91
nvidia-cusparse-cu12          12.1.0.106
nvidia-nccl-cu11              2.14.3
nvidia-nccl-cu12              2.18.1
nvidia-nvjitlink-cu12         12.3.101
nvidia-nvtx-cu11              11.7.91
nvidia-nvtx-cu12              12.1.105
onnx                          1.15.0
onnxruntime                   1.15.1
onnxsim                       0.4.35
packaging                     23.2
pip                           23.3.2
protobuf                      4.25.1
Pygments                      2.17.2
PyYAML                        6.0.1
regex                         2023.12.25
requests                      2.31.0
rich                          13.7.0
safetensors                   0.4.1
sentencepiece                 0.1.99
setuptools                    57.4.0
sympy                         1.12
tabulate                      0.9.0
tokenizers                    0.13.3
torch                         2.0.1
tqdm                          4.66.1
transformers                  4.31.0
transformers-stream-generator 0.0.4
triton                        2.0.0
typing_extensions             4.9.0
urllib3                       2.1.0
wheel                         0.42.0
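The shapes reported above can be checked mechanically. can_concat below is an illustrative helper (not part of llm-export or PyTorch) that mirrors what ONNX Concat requires: equal ranks, and equal sizes on every axis except the concatenation axis.

```python
# Illustrative shape-level concat check (not part of llm-export or
# PyTorch): inputs must share a rank and agree on all non-concat axes.
def can_concat(shapes, axis):
    rank = len(shapes[0])
    if any(len(s) != rank for s in shapes):
        return False  # differing ranks, as in this issue (4 vs 6)
    return all(
        s[i] == shapes[0][i]
        for s in shapes
        for i in range(rank)
        if i != axis
    )

# The shapes from this issue cannot be concatenated on any axis:
print(can_concat([(1, 32, 0, 128), (1, 1, 3, 32, 3, 128)], axis=2))  # False
# A rank-4 pair agreeing on the non-concat axes would be fine:
print(can_concat([(1, 32, 0, 128), (1, 32, 3, 128)], axis=2))        # True
```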
mi-tao commented 6 months ago

By diffing against the original modeling_llama.py, I located the modified code (see the attached screenshot). After restoring the two squeeze calls, the export succeeded. Stepping through with a debugger, torch.squeeze reduces the tensor to [seq, dim] as expected in the forward pass, so the root cause of the export failure is still unclear.
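The shapes below are hypothetical stand-ins (the real ones come from the modified model code), but they illustrate why a missing squeeze can break the export even though the forward pass runs fine: a leftover singleton axis changes the tensor's rank, and the ONNX cat() symbolic rejects inputs whose ranks differ.

```python
# Hypothetical shapes illustrating the fix described above: a leftover
# size-1 axis changes a tensor's rank, which the ONNX cat() symbolic
# rejects; squeezing it away makes the ranks match again.
def squeeze(shape, dim):
    # shape-level analogue of torch.squeeze(t, dim)
    assert shape[dim] == 1, "can only squeeze a size-1 axis"
    return shape[:dim] + shape[dim + 1:]

past_kv = (1, 32, 0, 128)    # cached key/values, rank 4
new_kv = (1, 1, 32, 1, 128)  # hypothetical unsqueezed result, rank 5
print(len(past_kv) == len(new_kv))               # False: export fails
print(len(past_kv) == len(squeeze(new_kv, 1)))   # True: ranks match
```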