microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

The input tensor cannot be reshaped to the requested shape after adding Gather output to model's output #16670

Open mengniwang95 opened 1 year ago

mengniwang95 commented 1 year ago

Describe the issue

Adding a Gather node's output to the model's outputs triggers this shape issue (except for the /model/embed_tokens/Gather node in llama). I found that the first sample in the NeelNanda/pile-10k dataset runs successfully, but inference fails from the second sample onward.

To reproduce

Prepare model:

optimum-cli export onnx --model decapoda-research/llama-7b-hf --task causal-lm-with-past --for-ort --device cpu llama-7b-onnx

python -m onnxruntime.transformers.optimizer --input llama-7b-onnx/decoder_model.onnx --output llama-7b-onnx-opt/decoder_model.onnx --num_heads 32 --hidden_size 4096 --model_type gpt2 --use_external_data_format

Run: https://github.com/microsoft/onnxruntime-inference-examples/tree/main/quantization/language_model/llama#1-quantization

The model is too large to upload, so you can get it from the nc_workspace folder after executing the commands above; its name is Optimized_model.onnx_augment.onnx.

Urgency

No response

Platform

Linux

OS Version

7

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.14.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

wschin commented 1 year ago

My guess is that Gather's output shape got fixed to a concrete value (e.g., [3, 2, 2048] if you have batch 3, sequence length 2, and embedding dimension 2048). You can print the ModelProto.graph.input and ModelProto.graph.output fields to inspect the model's input and output schemas. If so, you can correct this manually by editing ModelProto.graph.output[i].type (type: https://github.com/onnx/onnx/blob/b18533454a8ec0594d56aa55faca16e0bf6a0917/onnx/onnx.proto#L188) where the i-th output is Gather's output. Alternatively, you can first export to ONNX following https://pytorch.org/docs/stable/onnx.html#example-alexnet-from-pytorch-to-onnx to verify that the ONNX exporter is correct.