robertknight / rten

ONNX neural network inference engine

Convert quantized models #42

Open igor-yusupov opened 6 months ago

igor-yusupov commented 6 months ago

Is it possible to convert quantized models? I tried and got this error: WARNING: Error converting initializer: Unsupported tensor data type uint8 for operator None

igor-yusupov commented 6 months ago

upd: I have changed uint32 to int32 and got errors:

Unsupported operator MatMulInteger
Unsupported operator DynamicQuantizeLinear

robertknight commented 6 months ago

There is no support for quantization yet. This is something I plan to add in future. Currently f32 / i32 are the only data types supported (i64 and bool tensors are also "supported", but they are converted to i32).

igor-yusupov commented 6 months ago

I got it. I would be happy to help if needed.

robertknight commented 6 months ago

Can you provide some details on the models you are testing with (eg. is it a well known public model or based on one?) and the tools/processes you are using to quantize the model?

igor-yusupov commented 6 months ago

yes, sure.

I want to run quantized models based on transformers (https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html). I can also send you example weights for the models.

I use the Python library onnxruntime for quantization. Here is an example of how to use it:

from onnxruntime.quantization import quantize_dynamic, QuantType

# Two alternatives (each call writes to the same output path):
# quantize the weights to unsigned 8-bit or to signed 8-bit integers.
quantize_dynamic(FP32_PATH, INT32_PATH, weight_type=QuantType.QUInt8)
quantize_dynamic(FP32_PATH, INT32_PATH, weight_type=QuantType.QInt8)

igor-yusupov commented 5 months ago

attached weights:
encoder: https://www.dropbox.com/scl/fi/qzdjxh0t6wgokmmrvb7a6/encoder.quant.onnx?rlkey=mp06v3xw2imq1ips1ligh5yu6&dl=0
decoder: https://www.dropbox.com/scl/fi/asrlmk15242b1dqbm9qvn/decoder.quant.onnx?rlkey=2i1f4ek50h25cyil0jhr70rqc&dl=0
generator: https://www.dropbox.com/scl/fi/wo4htil4kty2g25qeu33o/generator.quant.onnx?rlkey=ycs0bw7o1p17jgxf9s15j0e1s&dl=0

igor-yusupov commented 5 months ago

Returning to the question of how to add support for quantized models: do I understand correctly that what is needed is support for the MatMulInteger and DynamicQuantizeLinear operators?

I see the commit where you added the ReduceSumSquare operator (https://github.com/robertknight/rten/commit/36bbf89f6b7c92730ecdf437c2dd4bb3e315ee9c). Is it enough to add these operators in a similar way? Or is it a very complicated task, so that I'd be better off waiting for you to add this feature? :)

robertknight commented 4 months ago

Before doing any work in rten itself, I think the starting point will be to pick a couple of test cases covering different kinds of model (a simple CNN and a transformer model) and create working demos using ONNX Runtime (or some other engine), plus a working fp32 version with rten to compare against. This will provide a baseline for the expected accuracy and performance impact of going from fp32 to int8. This part doesn't require in-depth knowledge of the existing rten code.
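A baseline comparison of this kind could be structured roughly as follows. This is only a sketch: `time_runs`, `max_abs_diff` and `compare` are hypothetical helper names, and the two callables stand in for whatever wraps the fp32 and int8 models (for example an ONNX Runtime session or an rten model); no real engine API is used here.

```python
import time


def time_runs(run, inputs, repeats=5):
    """Call `run(inputs)` several times; return (best seconds, last output)."""
    best = float("inf")
    output = None
    for _ in range(repeats):
        start = time.perf_counter()
        output = run(inputs)
        best = min(best, time.perf_counter() - start)
    return best, output


def max_abs_diff(reference, candidate):
    """Largest element-wise difference between two flat output lists."""
    return max(abs(r - c) for r, c in zip(reference, candidate))


def compare(baseline, candidate, inputs, repeats=5):
    """Time both callables on the same inputs and measure output drift."""
    base_time, base_out = time_runs(baseline, inputs, repeats)
    cand_time, cand_out = time_runs(candidate, inputs, repeats)
    return {
        "baseline_s": base_time,
        "candidate_s": cand_time,
        "max_abs_diff": max_abs_diff(base_out, cand_out),
    }


# Stand-ins for real models: the "quantized" one rounds its outputs,
# which mimics a small loss of accuracy.
fp32_model = lambda xs: [v * 0.1 for v in xs]
int8_model = lambda xs: [round(v * 0.1, 1) for v in xs]

report = compare(fp32_model, int8_model, [1.0, 2.5, 3.7])
print(report)
```

Taking the best of several runs reduces noise from warm-up and other processes; for a real model the accuracy measure would more likely be task-specific (for Whisper, something like word error rate) rather than a raw output difference.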

To implement quantization support in rten itself, there are a few pieces involved:

To support integer operators, the most direct approach would be to convert the MatMulInteger and DynamicQuantizeLinear ONNX ops to matching rten operators. This is how the library mostly handles ONNX operators today. ONNX Runtime, however, does a bunch of graph optimizations before it actually executes the model. I'm not sure yet whether it will make sense for rten to do some of these as part of model conversion or defer them to runtime.

In terms of actually implementing an operator like MatMulInteger, it is easy to create a naive version which is very slow, but creating one that is fast involves more work, including architecture-specific code. This might be done by making parts of the existing gemm implementation for fp32 matrix multiplication generic and pluggable, or it might make more sense to have a separate code path. I'm not sure. Either way, it will involve a fair amount of work.
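To make the two operators concrete, here is a minimal pure-Python sketch based on the ONNX operator definitions. It is exactly the naive, slow kind of version mentioned above and is not rten's implementation. The function names, the use of plain lists instead of tensors, the guard for all-zero input, and the example weights and `w_scale` are all assumptions made for illustration.

```python
def dynamic_quantize_linear(x):
    """Quantize a flat list of floats to uint8 values.

    Returns (quantized values, scale, zero point).
    """
    qmin, qmax = 0, 255
    # The quantization range is widened so that it always contains 0.0.
    x_min = min(0.0, min(x))
    x_max = max(0.0, max(x))
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # assumption: avoid dividing by zero for all-zero input
    # Python's round() rounds halves to even, which is what the ONNX spec asks for.
    zero_point = min(qmax, max(qmin, round(qmin - x_min / scale)))
    y = [min(qmax, max(qmin, round(v / scale) + zero_point)) for v in x]
    return y, scale, zero_point


def matmul_integer(a, b, a_zero_point=0, b_zero_point=0):
    """Naive MatMulInteger: subtract zero points, accumulate in integers."""
    inner, cols = len(b), len(b[0])
    return [
        [
            sum((row[k] - a_zero_point) * (b[k][j] - b_zero_point) for k in range(inner))
            for j in range(cols)
        ]
        for row in a
    ]


# Approximate a float matmul: activations are quantized at runtime,
# weights are assumed to be already stored as int8 with a known scale.
x = [[0.5, -1.25, 2.0], [1.0, 0.0, -0.75]]
w_q = [[12, -7], [3, 90], [-128, 127]]
w_scale = 0.01

flat = [v for row in x for v in row]
x_q_flat, x_scale, x_zp = dynamic_quantize_linear(flat)
x_q = [x_q_flat[0:3], x_q_flat[3:6]]

acc = matmul_integer(x_q, w_q, a_zero_point=x_zp)
# Rescale the integer accumulators back to floats.
approx = [[v * x_scale * w_scale for v in row] for row in acc]
print(approx)
```

The triple loop in `matmul_integer` is where a fast implementation would differ most: a production kernel would tile the computation and use architecture-specific integer dot-product instructions instead.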

igor-yusupov commented 4 months ago

Got it, it sounds like a lot of work will have to be done. But if it works out, rten will be a very powerful tool for deploying neural networks.

I can help by starting on the first point, that is, building a test bench where models can be run with both onnxruntime and rten to compare inference speed. How would you like to see it implemented? Should this be a separate repository or should it be added to the rten repository? How do you see the structure of the project?

robertknight commented 4 months ago

> Should this be a separate repository or should this be added to the rten repository?

I think a separate repository would be easiest. The structure doesn't matter too much, as long as it is easy for someone else to run.

igor-yusupov commented 4 months ago

Hi! I have created a repo and added a Whisper example: https://github.com/igor-yusupov/comparisons-rten/tree/main/whisper

You can check it by running the Whisper model with rten and with onnxruntime (I also added quantized models).

Overall rten is pretty fast, only 1.5 times slower than onnxruntime. If it becomes possible to use quantized models, that would be very cool!

robertknight commented 4 months ago

This is awesome, thank you so much :) - A Whisper demo is super useful to have.

robertknight commented 4 months ago

I did some benchmarks on my 2020 Intel MacBook Pro (Intel i5-1038NG7 @ 2.00GHz) to get a sense of where performance is at.

To summarize for others reading this, the task here is transcribing a 1m58s audio file using the Whisper base model (links).

These results are in line with my expectations. The output with RTEN_TIMING for the encoder (1 run):

Graph run of 520 ops finished in 2730.454ms
MatMul           1813.05ms (66.40%)
[Mem alloc/free] 251.49ms  (9.21%)
Div              104.13ms  (3.81%)
Softmax          98.64ms   (3.61%)
Add              97.01ms   (3.55%)
Transpose        95.67ms   (3.50%)
Mul              82.57ms   (3.02%)
Conv             46.29ms   (1.70%)
Erf              42.47ms   (1.56%)
Sub              28.43ms   (1.04%)
Pow              27.85ms   (1.02%)
ReduceMean       25.12ms   (0.92%)
Reshape          15.21ms   (0.56%)
[Other]          1.57ms    (0.06%)
Slice            0.31ms    (0.01%)
Gather           0.27ms    (0.01%)
Concat           0.12ms    (0.00%)
Shape            0.11ms    (0.00%)
Unsqueeze        0.07ms    (0.00%)
Cast             0.04ms    (0.00%)
Sqrt             0.03ms    (0.00%)

And an example of a decoder run (288 runs):

Graph run of 1622 ops finished in 88.543ms
MatMul           59.78ms (67.52%)
Transpose        17.55ms (19.82%)
[Other]          2.65ms  (3.00%)
Add              1.63ms  (1.85%)
Concat           1.63ms  (1.84%)
Mul              1.52ms  (1.72%)
Reshape          0.95ms  (1.07%)
[Mem alloc/free] 0.80ms  (0.90%)
Softmax          0.60ms  (0.68%)
Gather           0.26ms  (0.29%)
ScatterND        0.16ms  (0.18%)
Unsqueeze        0.15ms  (0.17%)
Shape            0.12ms  (0.14%)
Div              0.12ms  (0.14%)
Slice            0.10ms  (0.11%)
Where            0.08ms  (0.09%)
ReduceMean       0.08ms  (0.09%)
Pow              0.08ms  (0.09%)
Sub              0.07ms  (0.08%)
Range            0.06ms  (0.07%)
Expand           0.03ms  (0.04%)
Cast             0.03ms  (0.04%)
Erf              0.02ms  (0.03%)
Sqrt             0.02ms  (0.02%)
Trilu            0.02ms  (0.02%)
Equal            0.01ms  (0.01%)
ConstantOfShape  0.01ms  (0.01%)

And top outputs for RTEN_TIMING=by-shape=1:

For the encoder:

MatMul           1686.08ms (66.33%)

    Shape                                 Count  Mean (ms)  Total (ms)  ns/input elem
    ------------------------------------  -----  ---------  ----------  -------------
    [4, 8, 1500, 64], [4, 8, 64, 1500]    6      110.186    661.116     17.934
    [4, 1500, 2048], [2048, 512]          6      53.166     318.997     3.986
    [4, 1500, 512], [512, 512]            24     10.645     255.484     3.193
    [4, 1500, 512], [512, 2048]           6      41.010     246.058     9.952
    [4, 8, 1500, 1500], [4, 8, 1500, 64]  6      34.071     204.429     0.454

For one example of a decoder run:

MatMul           58.23ms (67.13%)

    Shape                              Count  Mean (ms)  Total (ms)  ns/input elem
    ---------------------------------  -----  ---------  ----------  -------------
    [1, 1500, 512], [512, 512]         12     2.835      34.020      2.752
    [1, 1, 512], [512, 51865]          1      8.758      8.758       0.330
    [1, 1, 512], [512, 512]            36     0.165      5.924       0.627
    [1, 1, 512], [512, 2048]           6      0.417      2.504       0.398
    [1, 1, 2048], [2048, 512]          6      0.336      2.015       0.320
    [1, 8, 1, 1500], [1, 8, 1500, 64]  6      0.302      1.814       0.388
    [1, 8, 1, 64], [1, 8, 64, 1500]    6      0.228      1.369       0.297
    [1, 8, 1, 64], [1, 8, 64, 287]     6      0.169      1.011       1.143
    [1, 8, 1, 287], [1, 8, 287, 64]    6      0.136      0.816       0.911

Transpose        17.80ms (20.52%)

    Shape             Count  Mean (ms)  Total (ms)  ns/input elem
    ----------------  -----  ---------  ----------  -------------
    [1, 1500, 8, 64]  12     1.080      12.956      1.406
    [1, 287, 8, 64]   12     0.396      4.754       2.696
    [1, 8, 1, 64]     12     0.004      0.047       7.650
    [1, 1, 8, 64]     12     0.004      0.043       6.999
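The `ns/input elem` column in the tables above appears to be the mean time divided by the total number of elements across all inputs. That reading is inferred from the numbers rather than taken from rten's documentation. The sketch below reproduces the column under that assumption and also derives an approximate GFLOP/s figure for a MatMul row; both function names are hypothetical.

```python
from math import prod


def ns_per_input_elem(mean_ms, *shapes):
    """Mean time in nanoseconds divided by the element count of all inputs."""
    total_elems = sum(prod(shape) for shape in shapes)
    return mean_ms * 1e6 / total_elems


def matmul_gflops(mean_ms, a_shape, b_shape):
    """Approximate throughput of a (batched) matmul.

    Assumes the batch dimensions of `b` match or broadcast to those of `a`,
    and counts one multiply plus one add per inner-product term.
    """
    batch = prod(a_shape[:-2])
    m, k = a_shape[-2:]
    n = b_shape[-1]
    flops = 2 * batch * m * k * n
    return flops / (mean_ms * 1e-3) / 1e9


# First two encoder MatMul rows from the by-shape table.
rows = [
    (110.186, [4, 8, 1500, 64], [4, 8, 64, 1500]),
    (53.166, [4, 1500, 2048], [2048, 512]),
]
for mean_ms, a, b in rows:
    print(a, b, round(ns_per_input_elem(mean_ms, a, b), 3), round(matmul_gflops(mean_ms, a, b), 1))
```

For the first encoder row this gives 17.934 ns per input element, matching the table. Comparing the GFLOP/s figures across shapes is one way to see which MatMul shapes are running furthest from the hardware's peak.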

So from a quick initial look, it seems that improvements will most likely come from:

igor-yusupov commented 4 months ago

Yeah, getting close to the speed of whisper.cpp would be really cool :) Thank you for the performance review.

btw I have a fp16_quant script, it also reduces the weights' sizes. Is it worth trying to add fp16 support or is it better to focus on int8 support?

Is there anything I can do to help?

robertknight commented 4 months ago

> btw I have a fp16_quant script, it also reduces the weights' sizes.

Thanks, I saw that. I notice that the whisper-py demo (using onnxruntime) is much slower with the fp16 model than with fp32 (~34s vs ~18s) on my Intel i5 laptop. That chip doesn't have f16 support for most instructions, although it does support F16C instructions for fast f16 <-> f32 conversion. I guess whisper.cpp is being much smarter about how/when it converts between f16 and f32 on Intel.

Might be worth testing on a modern ARM device with native fp16 support to see how it compares to fp32.

> Is it worth trying to add fp16 support or is it better to focus on int8 support?

int8 was initially more interesting for me because it has better support in WebAssembly. That said, f16 / bf16 have the advantage of being simpler, especially where the hardware has native support.
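As a small illustration of why the 16-bit formats are simpler, the sketch below stores values as f16 and converts them back to full precision for arithmetic, which is the kind of f16 <-> f32 conversion that F16C accelerates in hardware. It uses only Python's `struct` module; the helper names are hypothetical, and the bf16 helper truncates rather than rounds to nearest, which is a simplification.

```python
import struct


def to_f16_bytes(values):
    """Pack floats as little-endian IEEE 754 half precision (2 bytes each)."""
    return struct.pack(f"<{len(values)}e", *values)


def from_f16_bytes(data):
    """Unpack half-precision bytes back into Python floats."""
    return list(struct.unpack(f"<{len(data) // 2}e", data))


def bf16_truncate(value):
    """Round-trip a float through bfloat16 by keeping the top 16 bits of its
    f32 encoding (truncation, not round-to-nearest)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", value))
    (result,) = struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))
    return result


weights = [1.0, 0.1, -2.5]
packed_f16 = to_f16_bytes(weights)
packed_f32 = struct.pack(f"<{len(weights)}f", *weights)

print(len(packed_f32), len(packed_f16))  # storage in bytes: f32 vs f16
print(from_f16_bytes(packed_f16))        # values after the f16 round trip
print([bf16_truncate(w) for w in weights])
```

No scale or zero point is needed, unlike int8: the conversion is a fixed bit-level transformation, at the cost of a smaller reduction in size (2x rather than 4x relative to f32).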

igor-yusupov commented 4 months ago

I tried fp16 on a new MacBook's ARM processor and it is also slower than fp32. I agree that int8 is more interesting because it reduces the weights' size more than fp16 does. Thank you!

robertknight commented 4 months ago

> I tried fp16 on a new MacBook's ARM processor and it is also slower than fp32.

This is surprising. I would have expected native fp16 to be faster. I wonder if that is due to the hardware or to how ONNX Runtime is executing the model. With whisper.cpp, I see a big difference in performance (like ~4x) when native fp16 <-> fp32 conversion is enabled (via F16C instructions) compared to when it is not.