huggingface / optimum

🚀 Accelerate training and inference of 🤗 Transformers and 🤗 Diffusers with easy to use hardware optimization tools
https://huggingface.co/docs/optimum/main/
Apache License 2.0

ONNX support for Mixtral text-classification #1671

Open piyushdevlpr opened 9 months ago

piyushdevlpr commented 9 months ago

Feature request

I have been trying to optimize a Mixtral model trained for text classification by converting it to ONNX.

I tried exporting it directly at the PyTorch level with `torch.onnx.export`:

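```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Hypothetical paths: the fine-tuned checkpoint and output file are not named in the issue.
model_path = "path/to/mixtral-text-classification"
output_path = "mixtral_classification.onnx"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()

# Example inputs used for tracing (input_ids and attention_mask).
inputs = tokenizer("An example sentence to classify.", return_tensors="pt")

torch.onnx.export(
    model,
    args=tuple(inputs.values()),
    f=output_path,
    opset_version=15,
    do_constant_folding=True,
    input_names=["input_ids", "attention_mask"],
    output_names=["score"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence"},
        "attention_mask": {0: "batch_size", 1: "sequence"},
        "score": {0: "batch_size"},
    },
    verbose=False,
)
```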

Exporting the model produces the following warnings:

```
/home/piyush/miniconda3/envs/tf/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py:114: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
/home/piyush/miniconda3/envs/tf/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py:162: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if past_key_values_length > 0:
/home/piyush/miniconda3/envs/tf/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py:177: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if seq_len > self.max_seq_len_cached:
/home/piyush/miniconda3/envs/tf/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py:329: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
/home/piyush/miniconda3/envs/tf/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py:336: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
/home/piyush/miniconda3/envs/tf/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py:348: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
/home/piyush/miniconda3/envs/tf/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py:823: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if top_x.shape[0] == 0:
/home/piyush/miniconda3/envs/tf/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py:827: TracerWarning: Converting a tensor to a Python list might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  top_x_list = top_x.tolist()
/home/piyush/miniconda3/envs/tf/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py:828: TracerWarning: Converting a tensor to a Python list might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  idx_list = idx.tolist()
/home/piyush/miniconda3/envs/tf/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py:5589: UserWarning: Exporting aten::index operator of advanced indexing in opset 15 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.
  warnings.warn(
/home/piyush/miniconda3/envs/tf/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py:6350: UserWarning: Warning: ONNX export does not support duplicated values in 'index' field, this will cause the ONNX model to be incorrect.
```
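Several of these warnings point at the sparse mixture-of-experts block (`modeling_mixtral.py:823-828`), which builds per-expert token index lists (`top_x.tolist()`) and skips experts that received no tokens (`if top_x.shape[0] == 0`). During tracing, such Python values are recorded as constants. The toy sketch below mimics that effect in plain Python; the top-1 router and the `fake_trace` helper are hypothetical simplifications (Mixtral actually routes each token to its top 2 experts, and the real tracer works on tensors), but the failure mode is the same: the index list computed for the example input is replayed for every later input.

```python
from typing import Callable, List

Scores = List[List[float]]  # one row of gate scores per token


def route_top1(gate_scores: Scores) -> List[int]:
    """Hypothetical toy router: send each token to its highest-scoring expert."""
    return [max(range(len(row)), key=row.__getitem__) for row in gate_scores]


def tokens_for_expert(gate_scores: Scores, expert: int) -> List[int]:
    """Eager behaviour: recompute which tokens go to `expert` on every call."""
    return [tok for tok, e in enumerate(route_top1(gate_scores)) if e == expert]


def fake_trace(example: Scores, expert: int) -> Callable[[Scores], List[int]]:
    """Mimic a tracer: a Python value computed on the example becomes a constant."""
    frozen = tokens_for_expert(example, expert)

    def traced(gate_scores: Scores) -> List[int]:
        return frozen  # the new input is ignored, as the TracerWarning describes

    return traced


example = [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]]  # 3 tokens, 2 experts
traced_expert0 = fake_trace(example, expert=0)

new_input = [[0.1, 0.9], [0.6, 0.4], [0.8, 0.2], [0.9, 0.1]]  # 4 tokens
print(tokens_for_expert(new_input, 0))  # [1, 2, 3]
print(traced_expert0(new_input))        # [0, 2]
```

Note that both the indices and the list length are frozen, so any shape derived from them inside the exported graph is frozen too.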

Running inference with this ONNX model fails with the following error:

```
InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Expand node. Name:'/model/layers.0/block_sparse_moe/Expand_1' Status Message: invalid expand shape
```
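The ONNX `Expand` operator broadcasts its input to a target shape with numpy-style rules: shapes are right-aligned, and each pair of dimensions must be equal or one of them must be 1. One plausible reading of this error (an assumption, not confirmed in the thread) is that a shape derived from the traced example, such as the number of tokens routed to an expert in `block_sparse_moe`, was baked into the graph and no longer matches the runtime tensor. The helper `expand_output_shape` below is a hypothetical pure-Python re-implementation of that shape rule, not an ONNX Runtime API, and the dimensions are illustrative.

```python
from typing import List, Optional, Sequence, Tuple


def expand_output_shape(input_shape: Sequence[int],
                        target_shape: Sequence[int]) -> Optional[Tuple[int, ...]]:
    """Return the shape Expand would produce, or None if the shapes are incompatible.

    Shapes are right-aligned; each pair of dimensions must be equal or one of
    them must be 1 (numpy broadcasting, which ONNX Expand follows).
    """
    ndim = max(len(input_shape), len(target_shape))
    # Left-pad the shorter shape with 1s so both have the same rank.
    a = [1] * (ndim - len(input_shape)) + list(input_shape)
    b = [1] * (ndim - len(target_shape)) + list(target_shape)
    out: List[int] = []
    for x, y in zip(a, b):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            return None  # incompatible dimensions: the kind of case that gets rejected
    return tuple(out)


# Illustrative: the target shape was fixed while tracing with 2 routed tokens.
print(expand_output_shape((2, 1), (2, 4096)))  # (2, 4096)
# At inference, 3 tokens reach the same expert, but the frozen target still says 2.
print(expand_output_shape((3, 1), (2, 4096)))  # None
```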

Motivation

To use Mixtral models for fast inference using ONNX.

Your contribution

I am new to these things, but can definitely try to raise a PR.

claeyzre commented 9 months ago

Mixtral is a mixture-of-experts (MoE) model, and AFAIK MoE is not yet supported by ONNX Runtime (ORT). There is a PR for it, but it has had no commits for 3 weeks, so I don't know whether it is still being pursued on their side.

mgiessing commented 7 months ago

Mixtral MoE support was merged into ONNX Runtime last week: https://github.com/microsoft/onnxruntime/pull/19945

tp-nan commented 6 months ago

> Mixtral MoE was merged last week microsoft/onnxruntime#19945

Hi @mgiessing, do you know the proper way to export the MoE model from PyTorch to ONNX? Using `torch.onnx.export` might require modifying the conditional statements in the modeling code.
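One workaround sometimes used to make MoE layers traceable (a general technique, not something confirmed in this thread or in Optimum) is to replace the data-dependent dispatch with a dense computation: run every expert on every token and multiply by a routing weight that is zero for experts outside the top-k. No `if` statements or `.tolist()` calls remain for the tracer to freeze. The pure-Python sketch below uses Mixtral-style top-2 gating (softmax over all experts, keep the top 2, renormalize) with hypothetical scalar "experts" standing in for the real feed-forward networks, and checks that the dense form matches the sparse loop.

```python
import math
from typing import Callable, List

Expert = Callable[[float], float]  # hypothetical scalar stand-in for an expert FFN


def top2_weights(logits: List[float]) -> List[float]:
    """Softmax over all experts, keep the top 2, renormalize; all others get 0."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    probs = [e / z for e in exps]
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:2]
    total = sum(probs[i] for i in top)
    return [probs[i] / total if i in top else 0.0 for i in range(len(probs))]


def sparse_moe(xs: List[float], router_logits: List[List[float]],
               experts: List[Expert]) -> List[float]:
    """Data-dependent dispatch: only routed tokens are sent to each expert."""
    weights = [top2_weights(row) for row in router_logits]
    out = [0.0] * len(xs)
    for e, expert in enumerate(experts):
        tokens = [t for t in range(len(xs)) if weights[t][e] > 0.0]
        if not tokens:  # a Python branch on data: a tracer would freeze it
            continue
        for t in tokens:
            out[t] += weights[t][e] * expert(xs[t])
    return out


def dense_moe(xs: List[float], router_logits: List[List[float]],
              experts: List[Expert]) -> List[float]:
    """Branch-free: every expert runs on every token; zero weights mask it out."""
    weights = [top2_weights(row) for row in router_logits]
    return [sum(weights[t][e] * expert(xs[t]) for e, expert in enumerate(experts))
            for t in range(len(xs))]


experts: List[Expert] = [lambda x: 2 * x, lambda x: x + 1, lambda x: -x, lambda x: x * x]
xs = [0.5, -1.0, 3.0]
router_logits = [[2.0, 1.0, 0.1, -1.0],
                 [0.0, 0.3, 2.5, 0.2],
                 [1.0, 1.5, -0.5, 0.9]]

sparse = sparse_moe(xs, router_logits, experts)
dense = dense_moe(xs, router_logits, experts)
print(all(math.isclose(a, b) for a, b in zip(sparse, dense)))  # True
```

The trade-off is compute: for Mixtral (8 experts, top-2 routing) the dense form runs four times as many expert evaluations per token, in exchange for a graph whose shapes no longer depend on the routing decisions.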