Open lakhinderwalia opened 2 weeks ago
https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulNBits
According to the spec, the zero point can only be `uint8`/`int32`/`float16`/`float`?
And if the zero point isn't specified, it can be inferred to be `uint8` or `int32`, would be my guess. I should fix my test case, and do it for `int32` instead of `int8`.
@kahmed10, btw, I don't see any type checking in this parser code -- unless it is done somewhere else.
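For context, a minimal sketch of the default inference being discussed. The helper name and the midpoint default are assumptions about typical n-bit quantization conventions, not taken from the spec or the parser:

```cpp
#include <cstdint>

// Hypothetical helper, not from the parser: when the optional zero_points
// input is absent, n-bit quantization conventionally assumes the midpoint
// 2^(bits - 1) as the default zero point (8 for 4-bit weights), stored in
// one of the types the spec permits (uint8 or int32).
std::uint8_t make_default_zero_point(int bits)
{
    return static_cast<std::uint8_t>(1u << (bits - 1)); // 8 when bits == 4
}
```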
I don't think we can assume it is handled. If things are type constrained, you'll need to add that for those inputs.
Ok. Let me enhance this PR to include the basic type checking for this operator.
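A minimal sketch of what such a check could look like. The `dtype` enum and `check_zero_point_type` helper are placeholders rather than the actual MIGraphX parser API; the allowed set follows the spec quoted above (uint8/int32/float16/float):

```cpp
#include <set>
#include <stdexcept>

// Placeholder type tags standing in for the parser's dtype enum.
enum class dtype { uint8, int8, int32, float16, float32 };

// Reject a zero_point input whose type falls outside the set the
// MatMulNBits spec allows (uint8 / int32 / float16 / float).
void check_zero_point_type(dtype t)
{
    static const std::set<dtype> allowed = {
        dtype::uint8, dtype::int32, dtype::float16, dtype::float32};
    if(allowed.count(t) == 0)
        throw std::runtime_error("MatMulNBits: unsupported zero_point type");
}
```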
| Test | Batch | Rate new 14084c | Rate old 71fd27 | Diff | Compare |
|---|---|---|---|---|---|
| torchvision-resnet50 | 64 | 3,259.22 | 3,258.47 | 0.02% | :white_check_mark: |
| torchvision-resnet50_fp16 | 64 | nan | 6,989.35 | nan% | :x: |
| torchvision-densenet121 | 32 | 2,435.53 | 2,437.00 | -0.06% | :white_check_mark: |
| torchvision-densenet121_fp16 | 32 | nan | 4,070.38 | nan% | :x: |
| torchvision-inceptionv3 | 32 | 1,638.28 | 1,639.85 | -0.10% | :white_check_mark: |
| torchvision-inceptionv3_fp16 | 32 | nan | 2,763.41 | nan% | :x: |
| cadene-inceptionv4 | 16 | 776.15 | 776.50 | -0.04% | :white_check_mark: |
| cadene-resnext64x4 | 16 | 811.84 | 808.24 | 0.44% | :white_check_mark: |
| slim-mobilenet | 64 | 7,536.31 | 7,538.65 | -0.03% | :white_check_mark: |
| slim-nasnetalarge | 64 | 211.48 | 211.54 | -0.03% | :white_check_mark: |
| slim-resnet50v2 | 64 | nan | 3,507.21 | nan% | :x: |
| bert-mrpc-onnx | 8 | 1,151.20 | 1,150.51 | 0.06% | :white_check_mark: |
| bert-mrpc-tf | 1 | 499.42 | 475.19 | 5.10% | :high_brightness: |
| pytorch-examples-wlang-gru | 1 | 476.40 | 426.61 | 11.67% | :high_brightness: |
| pytorch-examples-wlang-lstm | 1 | 382.40 | 376.26 | 1.63% | :white_check_mark: |
| torchvision-resnet50_1 | 1 | 782.58 | 785.19 | -0.33% | :white_check_mark: |
| cadene-dpn92_1 | 1 | 398.76 | 399.09 | -0.08% | :white_check_mark: |
| cadene-resnext101_1 | 1 | 383.83 | 383.01 | 0.21% | :white_check_mark: |
| onnx-taau-downsample | 1 | nan | 343.03 | nan% | :x: |
| dlrm-criteoterabyte | 1 | 33.34 | 33.33 | 0.05% | :white_check_mark: |
| dlrm-criteoterabyte_fp16 | 1 | 52.71 | 52.73 | -0.04% | :white_check_mark: |
| agentmodel | 1 | 8,200.96 | 8,178.96 | 0.27% | :white_check_mark: |
| unet_fp16 | 2 | nan | 58.92 | nan% | :x: |
| resnet50v1_fp16 | 1 | nan | 925.30 | nan% | :x: |
| resnet50v1_int8 | 1 | nan | 1,011.55 | nan% | :x: |
| bert_base_cased_fp16 | 64 | nan | 1,169.93 | nan% | :x: |
| bert_large_uncased_fp16 | 32 | nan | 363.31 | nan% | :x: |
| bert_large_fp16 | 1 | nan | 200.50 | nan% | :x: |
| distilgpt2_fp16 | 16 | nan | 2,194.69 | nan% | :x: |
| yolov5s | 1 | 545.17 | 533.06 | 2.27% | :white_check_mark: |
| tinyllama | 1 | nan | 43.45 | nan% | :x: |
| vicuna-fastchat | 1 | 170.69 | 172.29 | -0.92% | :white_check_mark: |
| whisper-tiny-encoder | 1 | 417.98 | 417.95 | 0.01% | :white_check_mark: |
| whisper-tiny-decoder | 1 | nan | 425.65 | nan% | :x: |
This build is not recommended to merge :red_circle:
:x: bert_base_cased_fp16: ERROR - check error output
:x: bert_large_uncased_fp16: ERROR - check error output
:x: tinyllama: ERROR - check error output
:x: whisper-tiny-decoder: ERROR - check error output
:x: distilgpt2_fp16: ERROR - check error output
Is this still needed? Too many failures to approve
These failures likely have nothing to do with this change, @causten, because I added a fix for an `int8` type, which isn't exercised anywhere in the codebase. But this PR is still important for debugging our current fusion issues, as I am debugging one right now. It is kept as a back-burner item. Thanks.
Currently the matmulnbits parsing introduces a fixed type, `uint8_type`, for zero_point, and misses `int8_type`.
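A sketch of the shape of that fix; `select_zero_point_type` and the `dtype` tags are hypothetical stand-ins for however the parser actually represents types:

```cpp
// Hypothetical stand-in for the parser's type enum.
enum class dtype { uint8, int8, int32, float16, float32 };

// Take the zero_point literal's type from the supplied input instead of
// hard-coding uint8_type, so an int8 zero_point is preserved; fall back
// to uint8 only when the input is omitted.
dtype select_zero_point_type(bool has_zero_point, dtype supplied)
{
    return has_zero_point ? supplied : dtype::uint8;
}
```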