xenova / transformers.js

State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!
https://huggingface.co/docs/transformers.js
Apache License 2.0

add onnxslim integration #811

Closed: inisis closed this pull request 3 months ago

inisis commented 3 months ago

Hi, as discussed in https://github.com/xenova/transformers.js/issues/797, onnxslim can reduce the number of operators in a model. Some performance tests can be seen in https://github.com/huggingface/optimum/issues/1744.
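
The reports onnxslim prints are per-operator counts. The following minimal sketch shows how such a tally can be taken, including ops nested inside subgraphs such as the branches of an `If` node. It assumes a graph object shaped like `onnx.GraphProto`: nodes carry `op_type` and `attribute`, and graph-valued attributes expose `g` / `graphs`. `count_ops` is a hypothetical helper, not part of onnxslim.

```python
from collections import Counter


def _subgraphs(attr):
    """Yield graph-valued attributes (e.g. If's then/else branches, Loop bodies)."""
    g = getattr(attr, "g", None)
    if g is not None:
        yield g
    yield from getattr(attr, "graphs", ())


def count_ops(graph, counts=None):
    """Count node op types in `graph` and, recursively, in all of its subgraphs."""
    if counts is None:
        counts = Counter()
    for node in graph.node:
        counts[node.op_type] += 1
        for attr in getattr(node, "attribute", ()):
            for sub in _subgraphs(attr):
                count_ops(sub, counts)
    return counts


# Usage with a real model (requires the `onnx` package; path is a placeholder):
#   import onnx
#   model = onnx.load("model.onnx")
#   for op, n in sorted(count_ops(model.graph).items()):
#       print(f"{op:20s}{n}")
```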

xenova commented 3 months ago

Perfect! Can you also add onnxslim to scripts/requirements.txt (the version with support for subgraphs + weight tying)?

inisis commented 3 months ago

Hi @xenova, I've pinned the onnxslim version.

inisis commented 3 months ago

Hi @xenova, do you have test scripts for model accuracy and speed? I have a GPU server available for testing.

xenova commented 3 months ago

> Hi @xenova onnxslim version fixed

Great! I'm testing out the PR now with a bunch of MobileNet models. Looks like it's working great!

The ONNX export succeeded and the exported model was saved at: models/google/mobilenet_v2_1.0_224
+-------------------+--------------------------------------+--------------------------------------+
|    Model Name     |              model.onnx              |              Op Set: 11              |
+-------------------+--------------------------------------+--------------------------------------+
|    Model Info     |            Original Model            |            Slimmed Model             |
+-------------------+--------------------------------------+--------------------------------------+
| IN: pixel_values  | float32: ('batch_size', 3, 224, 224) | float32: ('batch_size', 3, 224, 224) |
|    OUT: logits    |    float32: ('batch_size', 1001)     |    float32: ('batch_size', 1001)     |
+-------------------+--------------------------------------+--------------------------------------+
|        Add        |                  10                  |                  10                  |
|       Cast        |                  52                  |                  0                   |
|       Clip        |                  35                  |                  35                  |
|      Concat       |                  52                  |                  0                   |
|     Constant      |                 538                  |                  0                   |
|  ConstantOfShape  |                  52                  |                  0                   |
|       Conv        |                  52                  |                  52                  |
|      Flatten      |                  1                   |                  1                   |
|       Gemm        |                  1                   |                  1                   |
| GlobalAveragePool |                  1                   |                  1                   |
|        Pad        |                  52                  |                  0                   |
|      Reshape      |                 104                  |                  0                   |
|       Slice       |                  52                  |                  0                   |
|     Transpose     |                  52                  |                  0                   |
+-------------------+--------------------------------------+--------------------------------------+
|    Model Size     |               13.50 MB               |               13.35 MB               |
+-------------------+--------------------------------------+--------------------------------------+
|   Elapsed Time    |                                    1.25 s                                   |
+-------------------+--------------------------------------+--------------------------------------+

Huge improvements!
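
To put the improvement in numbers, the sketch below re-tallies the operator counts from the MobileNetV2 report above. The values are copied from that output; nothing is re-measured.

```python
# Operator counts copied from the onnxslim report for google/mobilenet_v2_1.0_224.
original = {
    "Add": 10, "Cast": 52, "Clip": 35, "Concat": 52, "Constant": 538,
    "ConstantOfShape": 52, "Conv": 52, "Flatten": 1, "Gemm": 1,
    "GlobalAveragePool": 1, "Pad": 52, "Reshape": 104, "Slice": 52,
    "Transpose": 52,
}
slimmed = {
    "Add": 10, "Clip": 35, "Conv": 52, "Flatten": 1, "Gemm": 1,
    "GlobalAveragePool": 1,
}

total_before = sum(original.values())
total_after = sum(slimmed.values())
removed_types = sorted(set(original) - set(slimmed))
reduction = 1 - total_after / total_before

print(f"nodes: {total_before} -> {total_after} ({reduction:.1%} fewer)")
print("op types eliminated:", ", ".join(removed_types))
```

This prints `nodes: 1054 -> 100 (90.5% fewer)`. Every eliminated type (Cast, Concat, Constant, ConstantOfShape, Pad, Reshape, Slice, Transpose) is shape-plumbing rather than compute, while the file size barely changes (13.50 MB to 13.35 MB).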

> Hi @xenova do you have tests scripts for model accuracy and speed, I have gpu server available for tests.

I have a bunch of unorganized Colab notebooks spread out, but nothing official or good enough for release. I would absolutely love to consolidate everything into a single evaluation script, so it can also be used to evaluate different quantization settings. Is this something you (or another community member) would be interested in developing?
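
A consolidated evaluation script could start from something like the sketch below. It compares two models on the same inputs for accuracy (maximum absolute difference per output) and speed (mean latency). `compare_models` is a hypothetical helper. It takes plain callables so it can wrap `onnxruntime.InferenceSession.run` for an original, slimmed, or quantized model.

```python
import time

import numpy as np


def compare_models(run_ref, run_test, feeds, n_runs=10):
    """Compare two inference callables on the same feeds.

    run_ref / run_test: callables mapping a feed dict to a list of output arrays,
    e.g. lambda f: session.run(None, f) for an onnxruntime.InferenceSession.
    """
    ref_out = run_ref(feeds)
    test_out = run_test(feeds)
    if len(ref_out) != len(test_out):
        raise ValueError(f"output count differs: {len(ref_out)} vs {len(test_out)}")

    # Accuracy: worst-case absolute deviation for each output tensor.
    max_abs_diff = []
    for a, b in zip(ref_out, test_out):
        a = np.asarray(a, dtype=np.float64)
        b = np.asarray(b, dtype=np.float64)
        if a.shape != b.shape:
            raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
        max_abs_diff.append(float(np.max(np.abs(a - b))))

    # Speed: average wall-clock time per call over n_runs calls.
    def mean_latency(run):
        start = time.perf_counter()
        for _ in range(n_runs):
            run(feeds)
        return (time.perf_counter() - start) / n_runs

    return {
        "max_abs_diff": max_abs_diff,
        "ref_latency_s": mean_latency(run_ref),
        "test_latency_s": mean_latency(run_test),
    }


# Usage (paths are placeholders):
#   import onnxruntime as ort
#   ref = ort.InferenceSession("model.onnx")
#   slim = ort.InferenceSession("model_slimmed.onnx")
#   feeds = {"pixel_values": np.random.rand(1, 3, 224, 224).astype(np.float32)}
#   print(compare_models(lambda f: ref.run(None, f), lambda f: slim.run(None, f), feeds))
```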

inisis commented 3 months ago

Can I work with you?

xenova commented 3 months ago

> Can I work with you

To improve the conversion script and evaluation? I would love that! I'm currently working on some other things (like Florence2 support), but feel free to submit a PR and I can review it 😎

xenova commented 3 months ago

Merged the PR! Thanks so much! 🔥

inisis commented 3 months ago

I'm now writing scripts to test all the models on Hugging Face under the xenova namespace.

inisis commented 3 months ago

btw, version 0.1.29.1 is unstable because of a bug we have since fixed (https://github.com/inisis/OnnxSlim/issues/10), so I would recommend the latest version, 0.1.31.

xenova commented 3 months ago

I've updated the version to 0.1.31 🤗👍
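
Since 0.1.29.1 is reported as unstable and 0.1.31 is the recommended minimum, a conversion script could refuse to run with older installs. Here is a minimal sketch, assuming plain numeric dotted versions. It does not handle pre-release tags such as `rc1`; `packaging.version` is the more robust choice for that. `check_onnxslim` and `is_at_least` are hypothetical helpers.

```python
from importlib.metadata import PackageNotFoundError, version

MIN_ONNXSLIM = "0.1.31"  # minimum version recommended in this thread


def parse_version(v):
    """Turn a dotted numeric version like '0.1.29.1' into a comparable tuple."""
    return tuple(int(part) for part in v.split("."))


def is_at_least(installed, minimum):
    # Pad with zeros so '0.1.31' and '0.1.31.0' compare as equal.
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_onnxslim():
    """Raise if onnxslim is missing or older than MIN_ONNXSLIM; return its version."""
    try:
        installed = version("onnxslim")
    except PackageNotFoundError:
        raise RuntimeError("onnxslim is not installed") from None
    if not is_at_least(installed, MIN_ONNXSLIM):
        raise RuntimeError(f"onnxslim {installed} found; please upgrade to >= {MIN_ONNXSLIM}")
    return installed
```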

xenova commented 2 months ago

@inisis I'm running into a few issues when quantizing models produced by onnxslim.

Here's an example model: decoder_with_past_model.zip

Quantization code:

```python
from onnxruntime.quantization import quantize_dynamic, QuantType

model_fp32 = './decoder_with_past_model.onnx'
model_quant = './decoder_with_past_model_quantized.onnx'
# quantize_dynamic writes the quantized model to model_quant (it returns nothing).
quantize_dynamic(
    model_input=model_fp32,
    model_output=model_quant,
    weight_type=QuantType.QInt8,
    extra_options={'EnableSubgraph': True},  # also quantize nodes inside subgraphs
    per_channel=False,
    reduce_range=False,
)
```

This works.

But if you run

onnxslim ./decoder_with_past_model.onnx ./decoder_with_past_model_slimmed.onnx

followed by

```python
from onnxruntime.quantization import quantize_dynamic, QuantType

# Read the slimmed model from where the onnxslim command above wrote it.
model_fp32 = './decoder_with_past_model_slimmed.onnx'
model_quant = './decoder_with_past_model_slimmed_quantized.onnx'
quantize_dynamic(
    model_input=model_fp32,
    model_output=model_quant,
    weight_type=QuantType.QInt8,
    extra_options={'EnableSubgraph': True},  # also quantize nodes inside subgraphs
    per_channel=False,
    reduce_range=False,
)
```

it produces the following error:

/usr/local/lib/python3.10/dist-packages/onnxruntime/quantization/onnx_quantizer.py in quantize_model(self)
    417             _, initializers_not_found = self.model.clean_initializers()
    418             if len(initializers_not_found) > 0:
--> 419                 raise RuntimeError("Invalid model with unknown initializers/tensors." + str(initializers_not_found))
    420 
    421         self.model.model.producer_name = __producer__

RuntimeError: Invalid model with unknown initializers/tensors.{'cross_attentions.0'}

Can you look into this? 🙏
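
The quantizer reports `cross_attentions.0` as an unknown tensor. In other words, something in the slimmed graph refers to a name that nothing defines. Below is a minimal sketch for locating such names before quantizing. It assumes the `onnx.GraphProto` layout (inputs, initializers, nodes, outputs, with subgraphs reached via `g` / `graphs` attributes). `find_unresolved_names` is hypothetical and order-insensitive; it does not reproduce ONNX Runtime's exact check.

```python
def _subgraphs(attr):
    """Yield graph-valued attributes (e.g. If branches, Loop/Scan bodies)."""
    g = getattr(attr, "g", None)
    if g is not None:
        yield g
    yield from getattr(attr, "graphs", ())


def find_unresolved_names(graph, outer_scope=frozenset()):
    """Return names that are consumed or declared as outputs but never defined.

    A subgraph may also use any name defined in its enclosing graph(s).
    """
    defined = set(outer_scope)
    defined.update(vi.name for vi in graph.input)
    defined.update(init.name for init in graph.initializer)
    for node in graph.node:
        defined.update(node.output)

    unresolved = set()
    for node in graph.node:
        # Empty strings mark omitted optional inputs, so they are skipped.
        unresolved.update(n for n in node.input if n and n not in defined)
        for attr in node.attribute:
            for sub in _subgraphs(attr):
                unresolved |= find_unresolved_names(sub, frozenset(defined))
    unresolved.update(out.name for out in graph.output if out.name not in defined)
    return unresolved


# Usage (requires `onnx`):
#   import onnx
#   print(find_unresolved_names(onnx.load("decoder_with_past_model_slimmed.onnx").graph))
```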

inisis commented 2 months ago

@xenova, sorry for the bug. You can try:

pip install git+https://github.com/inisis/OnnxSlim@main

btw, can I add you on LinkedIn?

xenova commented 2 months ago

Thanks for the quick fix! I can confirm it works. 👍 I will update to the latest version in requirements.txt when you release 👌

> btw, can I add you on linkedin

Sure, feel free to send a request! :)


I also ran into an issue with https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx. I get the following error:

Traceback (most recent call last):
  File "/usr/local/bin/onnxslim", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/onnxslim/cli/_main.py", line 271, in main
    slim(
  File "/usr/local/lib/python3.10/dist-packages/onnxslim/cli/_main.py", line 128, in slim
    model = optimize(model, skip_fusion_patterns)
  File "/usr/local/lib/python3.10/dist-packages/onnxslim/core/slim.py", line 111, in optimize
    model = optimize_model(graph, skip_fusion_patterns)
  File "/usr/local/lib/python3.10/dist-packages/onnxslim/core/optimizer.py", line 891, in optimize_model
    graph_constant_fold_inplace(graph)
  File "/usr/local/lib/python3.10/dist-packages/onnxslim/core/optimizer.py", line 113, in graph_constant_fold_inplace
    graph_constant_fold_inplace(subgraph)
  File "/usr/local/lib/python3.10/dist-packages/onnxslim/core/optimizer.py", line 113, in graph_constant_fold_inplace
    graph_constant_fold_inplace(subgraph)
  File "/usr/local/lib/python3.10/dist-packages/onnxslim/core/optimizer.py", line 117, in graph_constant_fold_inplace
    delete_node(node)
  File "/usr/local/lib/python3.10/dist-packages/onnxslim/core/optimizer.py", line 88, in delete_node
    input_node = node.i()
  File "/usr/local/lib/python3.10/dist-packages/onnxslim/onnx_graphsurgeon/ir/node.py", line 88, in i
    return self.inputs[tensor_idx].inputs[producer_idx]
IndexError: list index out of range
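
The last frame shows `node.i()` evaluating `self.inputs[tensor_idx].inputs[producer_idx]`. That raises `IndexError` whenever the input tensor has no producer node, for example a graph input or a constant. Below is a minimal sketch of a defensive lookup over objects shaped like onnx_graphsurgeon nodes and tensors. `producer_or_none` is hypothetical and not necessarily how onnxslim fixed the bug.

```python
def producer_or_none(node, tensor_idx=0, producer_idx=0):
    """Return the node producing node.inputs[tensor_idx], or None if there is none.

    Uses the same indexing as `node.i()` but returns None instead of raising
    IndexError when the tensor is a graph input, a constant, or otherwise
    has no producer.
    """
    if tensor_idx >= len(node.inputs):
        return None
    producers = node.inputs[tensor_idx].inputs
    if producer_idx >= len(producers):
        return None
    return producers[producer_idx]
```

A caller such as a node-deletion pass can then skip rewiring when the lookup returns `None` instead of crashing.
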
inisis commented 2 months ago

@xenova, bug fixed! I swear I have never seen so many ONNX models with this many subgraphs. Thanks for reporting this!

xenova commented 2 months ago

Thanks! The model does export correctly, but now it produces:

[ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from model.onnx failed:This is an invalid model. In Node, ("If_0", If, "", -1) : ("Equal_0_C": tensor(bool),) -> ("If_0_outputs_0": tensor(float),"If_0_outputs_1": tensor(float),) , Error Nodes in a graph must be topologically sorted, however input 'If_0_else_branch__Inline_0__/decoder/If_10_output_1' of node: 
name: If_0_else_branch__Inline_0__/decoder/Unsqueeze_20 OpType: Unsqueeze
 is not output of any previous nodes.

Example code:

```python
import numpy as np
import onnxruntime as ort

batch_size = 2
audio_chunk = np.zeros((batch_size, 256), dtype=np.float32)  # 256 samples per row
sr = np.array(16000, dtype=np.int64)                          # sample rate as an int64 scalar
state = np.zeros((2, batch_size, 128), dtype=np.float32)      # recurrent state

ort_sess = ort.InferenceSession('model.onnx')
outputs = ort_sess.run(None, {'input': audio_chunk, 'sr': sr, 'state': state})

# Print the result
print(outputs)
```

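
The ONNX Runtime error says a node inside `If_0`'s else-branch consumes a value before any earlier node in that branch produces it. In other words, the subgraph's nodes are not in topological order. Below is a minimal sketch of re-ordering one graph's node list with Kahn's algorithm, assuming nodes expose `input`/`output` name lists as in `onnx.NodeProto`. `topo_sort_nodes` is hypothetical. Names coming from graph inputs, initializers, or an enclosing graph must be passed in `available`. Values that a node's own subgraphs capture implicitly are not tracked.

```python
from collections import deque


def topo_sort_nodes(nodes, available):
    """Order nodes so every input name is produced before it is consumed.

    nodes: objects with `input` and `output` lists of tensor names.
    available: names defined outside these nodes (graph inputs, initializers,
    outer-scope values). Raises ValueError on cycles or undefined inputs.
    """
    producer = {}
    for idx, node in enumerate(nodes):
        for name in node.output:
            if name:
                producer[name] = idx

    # For each node, count the inputs that come from other nodes in the list.
    pending = [0] * len(nodes)
    consumers = [[] for _ in nodes]
    for idx, node in enumerate(nodes):
        for name in node.input:
            if not name or name in available:
                continue
            if name not in producer:
                raise ValueError(f"input {name!r} is never produced")
            pending[idx] += 1
            consumers[producer[name]].append(idx)

    # Repeatedly emit nodes whose inputs are all satisfied.
    ready = deque(i for i, count in enumerate(pending) if count == 0)
    order = []
    while ready:
        idx = ready.popleft()
        order.append(nodes[idx])
        for consumer in consumers[idx]:
            pending[consumer] -= 1
            if pending[consumer] == 0:
                ready.append(consumer)
    if len(order) != len(nodes):
        raise ValueError("graph contains a cycle")
    return order
```

Nodes that are already in a valid order keep their relative order, because ready nodes are emitted first-in, first-out.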
inisis commented 2 months ago

Sorry for the bug; I have fixed it. Thanks for your patience and help.

xenova commented 2 months ago

Thanks so much and no worries! 🤗

I've noticed another issue with quantization: quantizing the model after slimming it does not work as expected. Steps to reproduce:

1. Download the model:
   ```sh
   wget https://huggingface.co/onnx-community/whisper-tiny.en_timestamped/resolve/d4469fcf29fc2898f0d57632d811fa0ed21de5cc/onnx/decoder_model_merged.onnx
   ```
2. Run ONNXSlim:
   ```sh
   onnxslim decoder_model_merged.onnx decoder_model_merged_slimmed.onnx
   ```
3. Quantize the slimmed model:
   ```python
   from onnxruntime.quantization import quantize_dynamic, QuantType

   model_fp32 = './decoder_model_merged_slimmed.onnx'
   model_quant = './decoder_model_merged_slimmed_quantized.onnx'
   # quantize_dynamic writes the quantized model to model_quant (it returns nothing).
   quantize_dynamic(
       model_input=model_fp32,
       model_output=model_quant,
       weight_type=QuantType.QInt8,
       extra_options={'EnableSubgraph': True},
       per_channel=False,
       reduce_range=False,
   )
   ```
4. Check the model sizes:
   ```sh
   $ ls -l
   -rw-r--r-- 1 root root 118606545 Jul  2 14:01 decoder_model_merged.onnx
   -rw-r--r-- 1 root root 118662672 Jul  2 14:13 decoder_model_merged_slimmed.onnx
   -rw-r--r-- 1 root root 110356751 Jul  2 14:14 decoder_model_merged_slimmed_quantized.onnx
   ```

However, if you run quantization on the original model, you get a much smaller output:

```python
from onnxruntime.quantization import quantize_dynamic, QuantType

model_fp32 = './decoder_model_merged.onnx'
model_quant = './decoder_model_merged_quantized.onnx'
quantize_dynamic(
    model_input=model_fp32,
    model_output=model_quant,
    weight_type=QuantType.QInt8,
    extra_options={'EnableSubgraph': True},
    per_channel=False,
    reduce_range=False,
)
```

```sh
-rw-r--r-- 1 root root  30791841 Jul  2 14:15 decoder_model_merged_quantized.onnx
```

Any idea what's going wrong? Thanks!
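
For scale, quantizing the original model shrinks it from 118,606,545 to 30,791,841 bytes, about 3.9× smaller. That is roughly in line with swapping float32 weights for int8. The slimmed model only goes from 118,662,672 to 110,356,751 bytes, about 7% smaller, which suggests most of its weights stayed in float32. One way to check is to tally initializer bytes per data type, including initializers inside subgraphs.

Below is a minimal sketch, assuming `onnx.TensorProto`-style initializers (`dims`, `data_type`) and the `onnx.GraphProto` subgraph layout. `bytes_by_dtype` is hypothetical. Weights stored as `Constant` nodes rather than initializers are not counted.

```python
from collections import Counter
from math import prod

# Element sizes in bytes for common onnx.TensorProto.DataType enum values.
DTYPE_BYTES = {
    1: ("float32", 4), 2: ("uint8", 1), 3: ("int8", 1), 5: ("int16", 2),
    6: ("int32", 4), 7: ("int64", 8), 9: ("bool", 1), 10: ("float16", 2),
    11: ("float64", 8),
}


def _subgraphs(attr):
    """Yield graph-valued attributes (e.g. If branches, Loop bodies)."""
    g = getattr(attr, "g", None)
    if g is not None:
        yield g
    yield from getattr(attr, "graphs", ())


def bytes_by_dtype(graph, totals=None):
    """Sum initializer sizes per data type in `graph` and all nested subgraphs.

    Unknown data types are listed under 'dtype<N>' with 0 bytes.
    """
    if totals is None:
        totals = Counter()
    for init in graph.initializer:
        name, size = DTYPE_BYTES.get(init.data_type, (f"dtype{init.data_type}", 0))
        totals[name] += size * prod(init.dims)  # prod(()) == 1 for scalars
    for node in graph.node:
        for attr in node.attribute:
            for sub in _subgraphs(attr):
                bytes_by_dtype(sub, totals)
    return totals


# Usage (requires `onnx`):
#   import onnx
#   model = onnx.load("decoder_model_merged_slimmed_quantized.onnx")
#   print(bytes_by_dtype(model.graph))
```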

xenova commented 2 months ago

And another weird one: this model becomes empty?

wget https://huggingface.co/onnx-community/whisper-tiny_timestamped/resolve/ae48508b4bc9b594a3a84d21f4a365a29d8d66ad/onnx/decoder_model_merged_fp16.onnx

followed by:

onnxslim decoder_model_merged_fp16.onnx decoder_model_merged_fp16_slimmed.onnx

produces:

+-------------------------------------+-----------------------------------------+-----------------------------------------+
|             Model Name              |     decoder_model_merged_fp16.onnx      |               Op Set: 14                |
+-------------------------------------+-----------------------------------------+-----------------------------------------+
|             Model Info              |             Original Model              |              Slimmed Model              |
+-------------------------------------+-----------------------------------------+-----------------------------------------+
|            IN: input_ids            |          int64: ('batch_size',          |                                         |
|                                     |       'decoder_sequence_length')        |                                         |
|      IN: encoder_hidden_states      |         float32: ('batch_size',         |                                         |
|                                     |   'encoder_sequence_length / 2', 384)   |                                         |
|  IN: past_key_values.0.decoder.key  |       float32: ('batch_size', 6,        |                                         |
|                                     |   'past_decoder_sequence_length', 64)   |                                         |
| IN: past_key_values.0.decoder.value |       float32: ('batch_size', 6,        |                                         |
|                                     |   'past_decoder_sequence_length', 64)   |                                         |
|  IN: past_key_values.0.encoder.key  |       float32: ('batch_size', 6,        |                                         |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
| IN: past_key_values.0.encoder.value |       float32: ('batch_size', 6,        |                                         |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|  IN: past_key_values.1.decoder.key  |       float32: ('batch_size', 6,        |                                         |
|                                     |   'past_decoder_sequence_length', 64)   |                                         |
| IN: past_key_values.1.decoder.value |       float32: ('batch_size', 6,        |                                         |
|                                     |   'past_decoder_sequence_length', 64)   |                                         |
|  IN: past_key_values.1.encoder.key  |       float32: ('batch_size', 6,        |                                         |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
| IN: past_key_values.1.encoder.value |       float32: ('batch_size', 6,        |                                         |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|  IN: past_key_values.2.decoder.key  |       float32: ('batch_size', 6,        |                                         |
|                                     |   'past_decoder_sequence_length', 64)   |                                         |
| IN: past_key_values.2.decoder.value |       float32: ('batch_size', 6,        |                                         |
|                                     |   'past_decoder_sequence_length', 64)   |                                         |
|  IN: past_key_values.2.encoder.key  |       float32: ('batch_size', 6,        |                                         |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
| IN: past_key_values.2.encoder.value |       float32: ('batch_size', 6,        |                                         |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|  IN: past_key_values.3.decoder.key  |       float32: ('batch_size', 6,        |                                         |
|                                     |   'past_decoder_sequence_length', 64)   |                                         |
| IN: past_key_values.3.decoder.value |       float32: ('batch_size', 6,        |                                         |
|                                     |   'past_decoder_sequence_length', 64)   |                                         |
|  IN: past_key_values.3.encoder.key  |       float32: ('batch_size', 6,        |                                         |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
| IN: past_key_values.3.encoder.value |       float32: ('batch_size', 6,        |                                         |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|        IN: use_cache_branch         |               bool: (1,)                |               bool: (1,)                |
|     OUT: present.0.encoder.key      |       float32: ('batch_size', 6,        |         float32: (0, 6, 1, 64)          |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|    OUT: present.3.encoder.value     |       float32: ('batch_size', 6,        |         float32: (0, 6, 1, 64)          |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|             OUT: logits             |         float32: ('batch_size',         |         float32: ('batch_size',         |
|                                     |    'decoder_sequence_length', 51865)    |    'decoder_sequence_length', 51865)    |
|    OUT: present.3.decoder.value     |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     | 'past_decoder_sequence_length + 1', 64) | 'past_decoder_sequence_length + 1', 64) |
|     OUT: present.3.encoder.key      |       float32: ('batch_size', 6,        |         float32: (0, 6, 1, 64)          |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|     OUT: present.2.encoder.key      |       float32: ('batch_size', 6,        |         float32: (0, 6, 1, 64)          |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|    OUT: present.0.decoder.value     |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     | 'past_decoder_sequence_length + 1', 64) | 'past_decoder_sequence_length + 1', 64) |
|     OUT: present.1.encoder.key      |       float32: ('batch_size', 6,        |         float32: (0, 6, 1, 64)          |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|     OUT: present.2.decoder.key      |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     | 'past_decoder_sequence_length + 1', 64) | 'past_decoder_sequence_length + 1', 64) |
|       OUT: cross_attentions.1       |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     |       'decoder_sequence_length',        |       'decoder_sequence_length',        |
|                                     |     'encoder_sequence_length_out')      |     'encoder_sequence_length_out')      |
|    OUT: present.1.decoder.value     |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     | 'past_decoder_sequence_length + 1', 64) | 'past_decoder_sequence_length + 1', 64) |
|     OUT: present.0.decoder.key      |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     | 'past_decoder_sequence_length + 1', 64) | 'past_decoder_sequence_length + 1', 64) |
|    OUT: present.2.decoder.value     |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     | 'past_decoder_sequence_length + 1', 64) | 'past_decoder_sequence_length + 1', 64) |
|       OUT: cross_attentions.0       |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     |       'decoder_sequence_length',        |       'decoder_sequence_length',        |
|                                     |     'encoder_sequence_length_out')      |     'encoder_sequence_length_out')      |
|       OUT: cross_attentions.2       |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     |       'decoder_sequence_length',        |       'decoder_sequence_length',        |
|                                     |     'encoder_sequence_length_out')      |     'encoder_sequence_length_out')      |
|       OUT: cross_attentions.3       |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     |       'decoder_sequence_length',        |       'decoder_sequence_length',        |
|                                     |     'encoder_sequence_length_out')      |     'encoder_sequence_length_out')      |
|    OUT: present.1.encoder.value     |       float32: ('batch_size', 6,        |         float32: (0, 6, 1, 64)          |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|     OUT: present.3.decoder.key      |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     | 'past_decoder_sequence_length + 1', 64) | 'past_decoder_sequence_length + 1', 64) |
|     OUT: present.1.decoder.key      |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     | 'past_decoder_sequence_length + 1', 64) | 'past_decoder_sequence_length + 1', 64) |
|    OUT: present.0.encoder.value     |       float32: ('batch_size', 6,        |         float32: (0, 6, 1, 64)          |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|    OUT: present.2.encoder.value     |       float32: ('batch_size', 6,        |         float32: (0, 6, 1, 64)          |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
+-------------------------------------+-----------------------------------------+-----------------------------------------+
|                 Add                 |                   153                   |                    0                    |
|                Cast                 |                   41                    |                   21                    |
|               Concat                |                   173                   |                    0                    |
|              Constant               |                   762                   |                    0                    |
|           ConstantOfShape           |                    2                    |                    0                    |
|                 Div                 |                   34                    |                    0                    |
|                Equal                |                    1                    |                    0                    |
|                 Erf                 |                    8                    |                    0                    |
|               Expand                |                    1                    |                    0                    |
|               Gather                |                   56                    |                    0                    |
|                 If                  |                    1                    |                    1                    |
|                Less                 |                    1                    |                    0                    |
|               MatMul                |                   106                   |                    0                    |
|                 Mul                 |                   75                    |                    0                    |
|                 Pow                 |                   26                    |                    0                    |
|                Range                |                    1                    |                    0                    |
|             ReduceMean              |                   52                    |                    0                    |
|               Reshape               |                   164                   |                    0                    |
|                Shape                |                   57                    |                    0                    |
|                Slice                |                    4                    |                    0                    |
|               Softmax               |                   16                    |                    0                    |
|                Sqrt                 |                   26                    |                    0                    |
|               Squeeze               |                    2                    |                    0                    |
|                 Sub                 |                   27                    |                    0                    |
|              Transpose              |                   74                    |                    0                    |
|              Unsqueeze              |                   301                   |                    0                    |
|                Where                |                    2                    |                    0                    |
+-------------------------------------+-----------------------------------------+-----------------------------------------+
|             Model Size              |                56.90 MB                 |                 8.93 KB                 |
+-------------------------------------+-----------------------------------------+-----------------------------------------+
|            Elapsed Time             |                                      1.39 s                                       |
+-------------------------------------+-----------------------------------------+-----------------------------------------+
inisis commented 2 months ago

And another weird one: this model becomes empty?

wget https://huggingface.co/onnx-community/whisper-tiny_timestamped/resolve/ae48508b4bc9b594a3a84d21f4a365a29d8d66ad/onnx/decoder_model_merged_fp16.onnx

followed by:

onnxslim decoder_model_merged_fp16.onnx decoder_model_merged_fp16_slimmed.onnx

produces:

+-------------------------------------+-----------------------------------------+-----------------------------------------+
|             Model Name              |     decoder_model_merged_fp16.onnx      |               Op Set: 14                |
+-------------------------------------+-----------------------------------------+-----------------------------------------+
|             Model Info              |             Original Model              |              Slimmed Model              |
+-------------------------------------+-----------------------------------------+-----------------------------------------+
|            IN: input_ids            |          int64: ('batch_size',          |                                         |
|                                     |       'decoder_sequence_length')        |                                         |
|      IN: encoder_hidden_states      |         float32: ('batch_size',         |                                         |
|                                     |   'encoder_sequence_length / 2', 384)   |                                         |
|  IN: past_key_values.0.decoder.key  |       float32: ('batch_size', 6,        |                                         |
|                                     |   'past_decoder_sequence_length', 64)   |                                         |
| IN: past_key_values.0.decoder.value |       float32: ('batch_size', 6,        |                                         |
|                                     |   'past_decoder_sequence_length', 64)   |                                         |
|  IN: past_key_values.0.encoder.key  |       float32: ('batch_size', 6,        |                                         |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
| IN: past_key_values.0.encoder.value |       float32: ('batch_size', 6,        |                                         |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|  IN: past_key_values.1.decoder.key  |       float32: ('batch_size', 6,        |                                         |
|                                     |   'past_decoder_sequence_length', 64)   |                                         |
| IN: past_key_values.1.decoder.value |       float32: ('batch_size', 6,        |                                         |
|                                     |   'past_decoder_sequence_length', 64)   |                                         |
|  IN: past_key_values.1.encoder.key  |       float32: ('batch_size', 6,        |                                         |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
| IN: past_key_values.1.encoder.value |       float32: ('batch_size', 6,        |                                         |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|  IN: past_key_values.2.decoder.key  |       float32: ('batch_size', 6,        |                                         |
|                                     |   'past_decoder_sequence_length', 64)   |                                         |
| IN: past_key_values.2.decoder.value |       float32: ('batch_size', 6,        |                                         |
|                                     |   'past_decoder_sequence_length', 64)   |                                         |
|  IN: past_key_values.2.encoder.key  |       float32: ('batch_size', 6,        |                                         |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
| IN: past_key_values.2.encoder.value |       float32: ('batch_size', 6,        |                                         |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|  IN: past_key_values.3.decoder.key  |       float32: ('batch_size', 6,        |                                         |
|                                     |   'past_decoder_sequence_length', 64)   |                                         |
| IN: past_key_values.3.decoder.value |       float32: ('batch_size', 6,        |                                         |
|                                     |   'past_decoder_sequence_length', 64)   |                                         |
|  IN: past_key_values.3.encoder.key  |       float32: ('batch_size', 6,        |                                         |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
| IN: past_key_values.3.encoder.value |       float32: ('batch_size', 6,        |                                         |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|        IN: use_cache_branch         |               bool: (1,)                |               bool: (1,)                |
|     OUT: present.0.encoder.key      |       float32: ('batch_size', 6,        |         float32: (0, 6, 1, 64)          |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|    OUT: present.3.encoder.value     |       float32: ('batch_size', 6,        |         float32: (0, 6, 1, 64)          |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|             OUT: logits             |         float32: ('batch_size',         |         float32: ('batch_size',         |
|                                     |    'decoder_sequence_length', 51865)    |    'decoder_sequence_length', 51865)    |
|    OUT: present.3.decoder.value     |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     | 'past_decoder_sequence_length + 1', 64) | 'past_decoder_sequence_length + 1', 64) |
|     OUT: present.3.encoder.key      |       float32: ('batch_size', 6,        |         float32: (0, 6, 1, 64)          |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|     OUT: present.2.encoder.key      |       float32: ('batch_size', 6,        |         float32: (0, 6, 1, 64)          |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|    OUT: present.0.decoder.value     |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     | 'past_decoder_sequence_length + 1', 64) | 'past_decoder_sequence_length + 1', 64) |
|     OUT: present.1.encoder.key      |       float32: ('batch_size', 6,        |         float32: (0, 6, 1, 64)          |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|     OUT: present.2.decoder.key      |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     | 'past_decoder_sequence_length + 1', 64) | 'past_decoder_sequence_length + 1', 64) |
|       OUT: cross_attentions.1       |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     |       'decoder_sequence_length',        |       'decoder_sequence_length',        |
|                                     |     'encoder_sequence_length_out')      |     'encoder_sequence_length_out')      |
|    OUT: present.1.decoder.value     |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     | 'past_decoder_sequence_length + 1', 64) | 'past_decoder_sequence_length + 1', 64) |
|     OUT: present.0.decoder.key      |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     | 'past_decoder_sequence_length + 1', 64) | 'past_decoder_sequence_length + 1', 64) |
|    OUT: present.2.decoder.value     |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     | 'past_decoder_sequence_length + 1', 64) | 'past_decoder_sequence_length + 1', 64) |
|       OUT: cross_attentions.0       |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     |       'decoder_sequence_length',        |       'decoder_sequence_length',        |
|                                     |     'encoder_sequence_length_out')      |     'encoder_sequence_length_out')      |
|       OUT: cross_attentions.2       |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     |       'decoder_sequence_length',        |       'decoder_sequence_length',        |
|                                     |     'encoder_sequence_length_out')      |     'encoder_sequence_length_out')      |
|       OUT: cross_attentions.3       |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     |       'decoder_sequence_length',        |       'decoder_sequence_length',        |
|                                     |     'encoder_sequence_length_out')      |     'encoder_sequence_length_out')      |
|    OUT: present.1.encoder.value     |       float32: ('batch_size', 6,        |         float32: (0, 6, 1, 64)          |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|     OUT: present.3.decoder.key      |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     | 'past_decoder_sequence_length + 1', 64) | 'past_decoder_sequence_length + 1', 64) |
|     OUT: present.1.decoder.key      |       float32: ('batch_size', 6,        |       float32: ('batch_size', 6,        |
|                                     | 'past_decoder_sequence_length + 1', 64) | 'past_decoder_sequence_length + 1', 64) |
|    OUT: present.0.encoder.value     |       float32: ('batch_size', 6,        |         float32: (0, 6, 1, 64)          |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
|    OUT: present.2.encoder.value     |       float32: ('batch_size', 6,        |         float32: (0, 6, 1, 64)          |
|                                     |   'encoder_sequence_length_out', 64)    |                                         |
+-------------------------------------+-----------------------------------------+-----------------------------------------+
|                 Add                 |                   153                   |                    0                    |
|                Cast                 |                   41                    |                   21                    |
|               Concat                |                   173                   |                    0                    |
|              Constant               |                   762                   |                    0                    |
|           ConstantOfShape           |                    2                    |                    0                    |
|                 Div                 |                   34                    |                    0                    |
|                Equal                |                    1                    |                    0                    |
|                 Erf                 |                    8                    |                    0                    |
|               Expand                |                    1                    |                    0                    |
|               Gather                |                   56                    |                    0                    |
|                 If                  |                    1                    |                    1                    |
|                Less                 |                    1                    |                    0                    |
|               MatMul                |                   106                   |                    0                    |
|                 Mul                 |                   75                    |                    0                    |
|                 Pow                 |                   26                    |                    0                    |
|                Range                |                    1                    |                    0                    |
|             ReduceMean              |                   52                    |                    0                    |
|               Reshape               |                   164                   |                    0                    |
|                Shape                |                   57                    |                    0                    |
|                Slice                |                    4                    |                    0                    |
|               Softmax               |                   16                    |                    0                    |
|                Sqrt                 |                   26                    |                    0                    |
|               Squeeze               |                    2                    |                    0                    |
|                 Sub                 |                   27                    |                    0                    |
|              Transpose              |                   74                    |                    0                    |
|              Unsqueeze              |                   301                   |                    0                    |
|                Where                |                    2                    |                    0                    |
+-------------------------------------+-----------------------------------------+-----------------------------------------+
|             Model Size              |                56.90 MB                 |                 8.93 KB                 |
+-------------------------------------+-----------------------------------------+-----------------------------------------+
|            Elapsed Time             |                                      1.39 s                                       |
+-------------------------------------+-----------------------------------------+-----------------------------------------+

It seems the original model itself is invalid. You can check it with the ONNX checker:

>>> import onnx
>>> model = onnx.load('/root/decoder_model_merged_fp16.onnx')
>>> onnx.checker.check_model(model)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/vipuser/anaconda3/lib/python3.9/site-packages/onnx/checker.py", line 179, in check_model
    C.check_model(
onnx.onnx_cpp2py_export.checker.ValidationError: Nodes in a graph must be topologically sorted, however input 'graph_input_cast_1' of node: 
name: /model/decoder/layers.0/encoder_attn/k_proj/MatMul OpType: MatMul
 is not output of any previous nodes.

==> Context: Bad node spec for node. Name: optimum::if OpType: If
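The rule the checker states can be modelled in a few lines. Walking the nodes in order, every input must be a graph input, an initializer, a value visible from an enclosing scope (relevant for nodes inside an If branch such as optimum::if), or an output of an earlier node. The sketch below uses a hypothetical Node dataclass in place of onnx.NodeProto and is only an illustration of the rule, not the onnx checker's actual implementation.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass
class Node:
    """Hypothetical, simplified stand-in for an ONNX NodeProto."""
    name: str
    inputs: list[str]
    outputs: list[str]


def first_order_violation(
    nodes: Iterable[Node],
    graph_inputs: Iterable[str] = (),
    initializers: Iterable[str] = (),
    outer_scope: Iterable[str] = (),
) -> Optional[tuple[str, str]]:
    """Return (input_name, node_name) for the first input that is not yet
    defined when nodes are visited in order, or None if the order is valid."""
    # Names available before any node in this graph runs. For an If branch,
    # outer_scope holds values defined in the enclosing graph.
    defined = set(graph_inputs) | set(initializers) | set(outer_scope)
    for node in nodes:
        for name in node.inputs:
            # ONNX uses "" to mark an omitted optional input.
            if name and name not in defined:
                return name, node.name
        defined.update(node.outputs)
    return None


if __name__ == "__main__":
    graph = [
        Node("cast", ["x"], ["x_cast"]),
        Node("matmul", ["x_cast", "w"], ["y"]),
    ]
    # Correct order: prints None
    print(first_order_violation(graph, graph_inputs=["x"], initializers=["w"]))
    # Reversed order: prints ('x_cast', 'matmul')
    print(first_order_violation(list(reversed(graph)), graph_inputs=["x"], initializers=["w"]))
```

Note that the same message is produced whether a value is merely defined too late or is never defined at all, so the error alone does not say which case applies.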
inisis commented 2 months ago

Thanks so much and no worries! 🤗

I've noticed another issue with quantization, where quantization after slimming the model does not work. Steps to reproduce:

  1. Download model
wget https://huggingface.co/onnx-community/whisper-tiny.en_timestamped/resolve/d4469fcf29fc2898f0d57632d811fa0ed21de5cc/onnx/decoder_model_merged.onnx
  2. Run ONNXSlim
onnxslim decoder_model_merged.onnx decoder_model_merged_slimmed.onnx
  3. Quantize the model
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

model_fp32 = './decoder_model_merged_slimmed.onnx'
model_quant = './decoder_model_merged_slimmed_quantized.onnx'
quantized_model = quantize_dynamic(
    model_input=model_fp32,
    model_output=model_quant,
    weight_type=QuantType.QInt8,
    extra_options={'EnableSubgraph': True},
    per_channel=False,
    reduce_range=False,
)
  4. Check model size:
$ ls -l
-rw-r--r-- 1 root root 118606545 Jul  2 14:01 decoder_model_merged.onnx
-rw-r--r-- 1 root root 118662672 Jul  2 14:13 decoder_model_merged_slimmed.onnx
-rw-r--r-- 1 root root 110356751 Jul  2 14:14 decoder_model_merged_slimmed_quantized.onnx

However, if you were to run quantization on the original model, you'd get a much smaller output:

import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

model_fp32 = './decoder_model_merged.onnx'
model_quant = './decoder_model_merged_quantized.onnx'
quantized_model = quantize_dynamic(
    model_input=model_fp32,
    model_output=model_quant,
    weight_type=QuantType.QInt8,
    extra_options={'EnableSubgraph': True},
    per_channel=False,
    reduce_range=False,
)
-rw-r--r-- 1 root root  30791841 Jul  2 14:15 decoder_model_merged_quantized.onnx

Any idea what's going wrong? Thanks!
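Turning the ls -l byte counts into ratios makes the gap concrete. Quantizing the original keeps about 26% of the file, close to the roughly 25% one would expect if most float32 weights (4 bytes each) become int8 (1 byte each), while quantizing the slimmed model keeps about 93%, which is consistent with most weights staying in float. The numbers are copied from the listings; the 25% figure is only a rule of thumb.

```python
# Byte counts copied from the ls -l listings above.
sizes = {
    "original": 118_606_545,
    "slimmed": 118_662_672,
    "slimmed_quantized": 110_356_751,
    "original_quantized": 30_791_841,
}


def ratio(after: str, before: str) -> float:
    """Fraction of the `before` file size that remains in `after`."""
    return sizes[after] / sizes[before]


# Prints 0.260: close to the ~0.25 expected for float32 -> int8 weights.
print(f"original -> quantized: {ratio('original_quantized', 'original'):.3f}")
# Prints 0.930: most of the slimmed model's bytes survive quantization.
print(f"slimmed  -> quantized: {ratio('slimmed_quantized', 'slimmed'):.3f}")
# Prints 56127: the slimmed fp32 model is slightly larger than the original.
print(f"slimmed - original: {sizes['slimmed'] - sizes['original']} bytes")
```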

This seems a little complicated. Can you help check the output correctness of the raw float model against the slimmed model?

I also think the quantized slimmed model is larger because some weights that should be tied are no longer tied.
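The output-correctness check asked for here can be sketched as a comparison of two {output_name: array} dicts, one from the original model and one from the slimmed model, fed the same inputs. The function name and the tolerances (rtol=1e-3, atol=1e-4) are assumptions chosen for illustration; shape mismatches are flagged explicitly, since the earlier table shows some slimmed outputs with fixed shapes such as (0, 6, 1, 64).

```python
import numpy as np


def compare_outputs(reference, candidate, rtol=1e-3, atol=1e-4):
    """Compare {output_name: array} dicts from the original and slimmed model.

    Returns {name: (max_abs_diff, allclose)}. Outputs missing from either
    side, or with mismatched shapes, are reported as (inf, False).
    """
    report = {}
    for name in sorted(set(reference) | set(candidate)):
        if name not in reference or name not in candidate:
            report[name] = (float("inf"), False)
            continue
        ref = np.asarray(reference[name])
        cand = np.asarray(candidate[name])
        if ref.shape != cand.shape:
            report[name] = (float("inf"), False)
            continue
        # Compute the difference in float64 so fp16/int outputs compare cleanly.
        diff = np.abs(ref.astype(np.float64) - cand.astype(np.float64))
        max_diff = float(diff.max()) if diff.size else 0.0
        close = bool(np.allclose(ref, cand, rtol=rtol, atol=atol))
        report[name] = (max_diff, close)
    return report
```

To fill the dicts, one option is onnxruntime.InferenceSession(path).run(None, feeds), zipped with the names from session.get_outputs(); for the merged decoder the feeds would need every input listed in the table (input_ids, encoder_hidden_states, the past_key_values tensors and use_cache_branch).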

inisis commented 2 months ago

I investigated this model: strangely, the quantized model still keeps the float weights.

inisis commented 2 months ago

I have raised an issue at https://github.com/microsoft/onnxruntime/issues/21277; the maintainers suggest running the quantization preprocessing step first to resolve this.
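A minimal sketch of what "preprocess, then quantize" could look like, assuming the preprocessing step meant is quant_pre_process from onnxruntime.quantization.shape_inference and reusing the quantize_dynamic options from earlier in this thread. The helper names and output file names are hypothetical, and the thread does not confirm that this fixes this particular model. The onnxruntime imports sit inside the function so the file can be imported where onnxruntime is not installed.

```python
from pathlib import Path


def output_paths(model_path: str, work_dir: str = ".") -> tuple[Path, Path]:
    """Hypothetical naming scheme for the intermediate and final files."""
    src = Path(model_path)
    out = Path(work_dir)
    return out / f"{src.stem}_preprocessed.onnx", out / f"{src.stem}_quantized.onnx"


def preprocess_then_quantize(model_path: str, work_dir: str = ".") -> Path:
    """Run ORT's quantization preprocessing, then dynamic int8 quantization."""
    # Imported lazily; requires onnxruntime at call time only.
    from onnxruntime.quantization import QuantType, quantize_dynamic
    from onnxruntime.quantization.shape_inference import quant_pre_process

    prepped, quantized = output_paths(model_path, work_dir)

    # Preprocessing runs shape inference and ORT graph optimization and
    # writes a new model file (individual steps can be skipped via flags
    # such as skip_optimization).
    quant_pre_process(str(model_path), str(prepped))

    # Same settings as the quantize_dynamic calls above.
    quantize_dynamic(
        model_input=str(prepped),
        model_output=str(quantized),
        weight_type=QuantType.QInt8,
        extra_options={"EnableSubgraph": True},
        per_channel=False,
        reduce_range=False,
    )
    return quantized
```

Comparing the size of the returned file with the roughly 30 MB obtained from the original model would be one quick way to tell whether the float weights were quantized this time.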

xenova commented 2 months ago

Thanks! 👍

I ran into another issue with this model:

wget https://github.com/pengzhendong/pyannote-onnx/raw/master/pyannote_onnx/segmentation-3.0.onnx
onnxslim segmentation-3.0.onnx slimmed.onnx

produces an invalid model:

+-----------------------+-------------------------------------+-------------------------------------+
|      Model Name       |        segmentation-3.0.onnx        |             Op Set: 17              |
+-----------------------+-------------------------------------+-------------------------------------+
|      Model Info       |           Original Model            |            Slimmed Model            |
+-----------------------+-------------------------------------+-------------------------------------+
|       IN: input       |      float32: ('B', 'C', 'T')       |      float32: ('B', 'C', 'T')       |
|      OUT: output      | float32: ('LogSoftmaxoutput_dim_0', | float32: ('LogSoftmaxoutput_dim_0', |
|                       |    'LogSoftmaxoutput_dim_1', 7)     |    'LogSoftmaxoutput_dim_1', 7)     |
+-----------------------+-------------------------------------+-------------------------------------+
|          Abs          |                  1                  |                  1                  |
|          Add          |                  3                  |                  3                  |
|        Concat         |                  3                  |                  3                  |
|    ConstantOfShape    |                  1                  |                  1                  |
|         Conv          |                  4                  |                  4                  |
|         Equal         |                  1                  |                  1                  |
|        Gather         |                  3                  |                  3                  |
|          If           |                  1                  |                  1                  |
| InstanceNormalization |                  4                  |                  4                  |
|         LSTM          |                  4                  |                  4                  |
|       LeakyRelu       |                  5                  |                  5                  |
|      LogSoftmax       |                  1                  |                  1                  |
|        MatMul         |                  3                  |                  3                  |
|        MaxPool        |                  3                  |                  3                  |
|        Reshape        |                  6                  |                  6                  |
|         Shape         |                  3                  |                  3                  |
|         Slice         |                  7                  |                  7                  |
|        Squeeze        |                  1                  |                  1                  |
|       Transpose       |                  6                  |                  6                  |
|       Unsqueeze       |                  2                  |                  2                  |
+-----------------------+-------------------------------------+-------------------------------------+
|      Model Size       |               5.71 MB               |              90.05 KB               |
+-----------------------+-------------------------------------+-------------------------------------+
|     Elapsed Time      |                                  0.21 s                                   |
+-----------------------+-------------------------------------+-------------------------------------+

Running

import onnx
onnx_model = onnx.load("slimmed.onnx")
onnx.checker.check_model(onnx_model)

produces

ValidationError: Nodes in a graph must be topologically sorted, however input 'ortshared_1_1_1_1_token_110' of node: 
name: /sincnet/wav_norm1d/InstanceNormalization OpType: InstanceNormalization
 is not output of any previous nodes.

Any idea what the problem is? Thanks!
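The checker reports the same "not output of any previous nodes" message for two different problems: a value produced too late (fixable by reordering) and a value that nothing defines at all. Here the model shrinks from 5.71 MB to 90.05 KB with identical op counts, which suggests the initializer behind the ortshared_... name may have been lost rather than misordered. The sketch below uses Kahn's algorithm on hypothetical (name, inputs, outputs) triples to reorder nodes and, separately, report dangling inputs; it is an illustration, not an onnxslim or onnx API.

```python
from __future__ import annotations

from collections import deque


def topo_sort(
    nodes: list[tuple[str, list[str], list[str]]],
    available: set[str],
) -> tuple[list[str], dict[str, set[str]]]:
    """Kahn's algorithm over (name, inputs, outputs) triples.

    `available` holds graph inputs and initializers. Returns node names in a
    valid execution order, plus a map of node name -> inputs that no node
    produces (dangling references, which no reordering can fix). Nodes that
    depend on a dangling input, directly or indirectly, are left out of the
    order.
    """
    producer = {out: name for name, _, outs in nodes for out in outs}
    # Keep only inputs that must come from another node ("" = omitted input).
    needs = {name: [i for i in ins if i and i not in available] for name, ins, _ in nodes}
    dangling = {name: {i for i in ins if i not in producer} for name, ins in needs.items()}
    dangling = {name: missing for name, missing in dangling.items() if missing}

    # pending counts unresolved inputs; dangling ones never get resolved.
    pending = {name: len(ins) for name, ins in needs.items()}
    consumers: dict[str, list[str]] = {name: [] for name, _, _ in nodes}
    for name, ins in needs.items():
        for i in ins:
            if i in producer:
                consumers[producer[i]].append(name)

    queue = deque(name for name, _, _ in nodes if pending[name] == 0)
    order: list[str] = []
    while queue:
        current = queue.popleft()
        order.append(current)
        for consumer in consumers[current]:
            pending[consumer] -= 1
            if pending[consumer] == 0:
                queue.append(consumer)
    return order, dangling
```

If dangling comes back non-empty, the graph is missing a definition and needs the lost initializer restored; reordering alone will not satisfy the checker.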

inisis commented 2 months ago

@xenova you can try pip install git+https://github.com/inisis/OnnxSlim@main. I have tested it, and I will release a new version tonight.

xenova commented 2 months ago

Thanks, I did try that, but I still see the same issue 👀 The last commit was ~1 week ago, correct?

inisis commented 2 months ago
(base) root@ubuntu20:~# pip install git+https://github.com/inisis/OnnxSlim@main
Looking in indexes: https://mirror.sjtu.edu.cn/pypi/web/simple
Collecting git+https://github.com/inisis/OnnxSlim@main
  Cloning https://github.com/inisis/OnnxSlim (to revision main) to /tmp/pip-req-build-xra0qpjh
  Running command git clone --filter=blob:none --quiet https://github.com/inisis/OnnxSlim /tmp/pip-req-build-xra0qpjh
  Resolved https://github.com/inisis/OnnxSlim to commit cdd20b9fbca86d1f40bc87bad60bb96b64fe3a1a
  Preparing metadata (setup.py) ... done
Requirement already satisfied: onnx in ./miniconda3/lib/python3.9/site-packages (from onnxslim==0.1.31) (1.16.1)
Requirement already satisfied: sympy in ./miniconda3/lib/python3.9/site-packages (from onnxslim==0.1.31) (1.12)
Requirement already satisfied: packaging in ./miniconda3/lib/python3.9/site-packages (from onnxslim==0.1.31) (23.2)
Requirement already satisfied: numpy>=1.20 in ./miniconda3/lib/python3.9/site-packages (from onnx->onnxslim==0.1.31) (1.26.3)
Requirement already satisfied: protobuf>=3.20.2 in ./miniconda3/lib/python3.9/site-packages (from onnx->onnxslim==0.1.31) (5.27.2)
Requirement already satisfied: mpmath>=0.19 in ./miniconda3/lib/python3.9/site-packages (from sympy->onnxslim==0.1.31) (1.3.0)
Building wheels for collected packages: onnxslim
  Building wheel for onnxslim (setup.py) ... done
  Created wheel for onnxslim: filename=onnxslim-0.1.31-py3-none-any.whl size=130494 sha256=fab9ca8d0e8d779b8ccafac9a02995aa881291645786707685fd00e11730dad2
  Stored in directory: /tmp/pip-ephem-wheel-cache-g0i_avbp/wheels/a4/f1/26/eeb33d2410214c343b741e995e8e846d16aec448bc8421d568
Successfully built onnxslim
Installing collected packages: onnxslim
Successfully installed onnxslim-0.1.31
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
(base) root@ubuntu20:~# onnxslim OnnxSlim/bin/segmentation-3.0.onnx slim.onnx
+-----------------------+-------------------------------------+-------------------------------------+
|      Model Name       |        segmentation-3.0.onnx        |             Op Set: 17              |
+-----------------------+-------------------------------------+-------------------------------------+
|      Model Info       |           Original Model            |            Slimmed Model            |
+-----------------------+-------------------------------------+-------------------------------------+
|       IN: input       |      float32: ('B', 'C', 'T')       |      float32: ('B', 'C', 'T')       |
|      OUT: output      | float32: ('LogSoftmaxoutput_dim_0', | float32: ('LogSoftmaxoutput_dim_0', |
|                       |    'LogSoftmaxoutput_dim_1', 7)     |    'LogSoftmaxoutput_dim_1', 7)     |
+-----------------------+-------------------------------------+-------------------------------------+
|          Abs          |                  1                  |                  1                  |
|          Add          |                  3                  |                  3                  |
|        Concat         |                  3                  |                  3                  |
|    ConstantOfShape    |                  1                  |                  1                  |
|         Conv          |                  4                  |                  4                  |
|         Equal         |                  1                  |                  1                  |
|        Gather         |                  3                  |                  3                  |
|          If           |                  1                  |                  1                  |
| InstanceNormalization |                  4                  |                  4                  |
|         LSTM          |                  4                  |                  4                  |
|       LeakyRelu       |                  5                  |                  5                  |
|      LogSoftmax       |                  1                  |                  1                  |
|        MatMul         |                  3                  |                  3                  |
|        MaxPool        |                  3                  |                  3                  |
|        Reshape        |                  6                  |                  6                  |
|         Shape         |                  3                  |                  3                  |
|         Slice         |                  7                  |                  7                  |
|        Squeeze        |                  1                  |                  1                  |
|       Transpose       |                  6                  |                  6                  |
|       Unsqueeze       |                  2                  |                  2                  |
+-----------------------+-------------------------------------+-------------------------------------+
|      Model Size       |               5.71 MB               |               5.71 MB               |
+-----------------------+-------------------------------------+-------------------------------------+
|     Elapsed Time      |                                  0.53 s                                   |
+-----------------------+-------------------------------------+-------------------------------------+
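With the main branch (built as onnxslim-0.1.31 from commit cdd20b9), the slimmed model keeps its 5.71 MB size here, yet the issue reportedly persists elsewhere. One way to rule out a stale install or a second environment is to check, from the interpreter that provides the onnxslim command, which version is installed and which file Python actually imports. This is a diagnostic sketch using only the standard library; the thread does not say this was the cause.

```python
from __future__ import annotations

import importlib.metadata
import importlib.util


def describe_install(package: str) -> dict[str, str | None]:
    """Report the installed distribution version and the module file Python would import."""
    try:
        version = importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        version = None
    spec = importlib.util.find_spec(package)
    origin = spec.origin if spec is not None else None
    return {"version": version, "origin": origin}


if __name__ == "__main__":
    # Run with the same Python environment that the onnxslim CLI uses.
    print(describe_install("onnxslim"))
```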
inisis commented 2 months ago

Hi @xenova, I have created a repo called OnnxLLM, specializing in LLM inference with onnxruntime. Currently supported models are llama3, qwen2 and chatglm3. I see that chatglm is not currently supported in your repo; can we work together to integrate it? Thanks!

inisis commented 2 months ago

Hi @xenova, I have made a PR in https://github.com/huggingface/optimum/issues/1744; could you help back me up? I think it's a very useful tool.