ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

matmulnbits zero_point fix #3566

Open lakhinderwalia opened 2 weeks ago

lakhinderwalia commented 2 weeks ago

Currently, the MatMulNBits parsing hardcodes uint8_type for zero_point and misses int8_type.

kahmed10 commented 2 weeks ago

https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulNBits according to the spec zero point can only be uint8/int32/float16/float?

lakhinderwalia commented 2 weeks ago

> https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulNBits according to the spec zero point can only be uint8/int32/float16/float?

And if the zero point isn't specified, my guess is that it can be inferred to be uint8 or int32. I should fix my test case, and do it for int32 instead of int8.

@kahmed10, btw, I don't see any type checking in this parser code -- unless it is done somewhere else.

TedThemistokleous commented 2 weeks ago

> > https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulNBits according to the spec zero point can only be uint8/int32/float16/float?
>
> And if the zero point isn't specified, it can be inferred to be uint8 or int32, would be my guess. I should fix my test case, and do it for int32 instead of int8.
>
> @kahmed10, btw, I don't see any type checking in this parser code -- unless it is done somewhere else.

I don't think we can assume it is handled. If those inputs are type-constrained, you'll need to add checks for them.

lakhinderwalia commented 2 weeks ago

> > > https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulNBits according to the spec zero point can only be uint8/int32/float16/float?
> >
> > And if the zero point isn't specified, it can be inferred to be uint8 or int32, would be my guess. I should fix my test case, and do it for int32 instead of int8. @kahmed10, btw, I don't see any type checking in this parser code -- unless it is done somewhere else.
>
> I don't think we can assume it is handled. If things are type constrained you'll need to add that for those inputs

Ok. Let me enhance this PR to include basic type checking for this operator.
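For reference, a minimal sketch of what such a check could look like. The enum and function names below are illustrative, not MIGraphX's actual API (the real parser works against `shape::type_t`); the allowed set follows the T3 type constraint in the linked contrib-op spec, under which int8 zero points should be rejected:

```cpp
#include <initializer_list>

// Illustrative type tags standing in for MIGraphX's shape::type_t;
// these names are hypothetical, not the actual enum values.
enum class dtype { uint8, int8, int32, float16, float32 };

// Per the ONNX Runtime contrib-op spec for com.microsoft.MatMulNBits,
// the optional zero_points input (T3) is constrained to uint8, int32,
// float16, or float. Anything else (e.g. int8) fails the check, and
// the parser would raise a parse error for it.
inline bool is_valid_zero_point_type(dtype t)
{
    for(dtype allowed : {dtype::uint8, dtype::int32, dtype::float16, dtype::float32})
    {
        if(t == allowed)
            return true;
    }
    return false;
}
```

A parser would call this on the zero_points argument's type before building the dequantize subgraph, and report an unsupported-type error when it returns false.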

migraphx-bot commented 2 weeks ago
| Test | Batch | Rate new (14084c) | Rate old (71fd27) | Diff | Compare |
| --- | --- | --- | --- | --- | --- |
| torchvision-resnet50 | 64 | 3,259.22 | 3,258.47 | 0.02% | :white_check_mark: |
| torchvision-resnet50_fp16 | 64 | nan | 6,989.35 | nan% | :x: |
| torchvision-densenet121 | 32 | 2,435.53 | 2,437.00 | -0.06% | :white_check_mark: |
| torchvision-densenet121_fp16 | 32 | nan | 4,070.38 | nan% | :x: |
| torchvision-inceptionv3 | 32 | 1,638.28 | 1,639.85 | -0.10% | :white_check_mark: |
| torchvision-inceptionv3_fp16 | 32 | nan | 2,763.41 | nan% | :x: |
| cadene-inceptionv4 | 16 | 776.15 | 776.50 | -0.04% | :white_check_mark: |
| cadene-resnext64x4 | 16 | 811.84 | 808.24 | 0.44% | :white_check_mark: |
| slim-mobilenet | 64 | 7,536.31 | 7,538.65 | -0.03% | :white_check_mark: |
| slim-nasnetalarge | 64 | 211.48 | 211.54 | -0.03% | :white_check_mark: |
| slim-resnet50v2 | 64 | nan | 3,507.21 | nan% | :x: |
| bert-mrpc-onnx | 8 | 1,151.20 | 1,150.51 | 0.06% | :white_check_mark: |
| bert-mrpc-tf | 1 | 499.42 | 475.19 | 5.10% | :high_brightness: |
| pytorch-examples-wlang-gru | 1 | 476.40 | 426.61 | 11.67% | :high_brightness: |
| pytorch-examples-wlang-lstm | 1 | 382.40 | 376.26 | 1.63% | :white_check_mark: |
| torchvision-resnet50_1 | 1 | 782.58 | 785.19 | -0.33% | :white_check_mark: |
| cadene-dpn92_1 | 1 | 398.76 | 399.09 | -0.08% | :white_check_mark: |
| cadene-resnext101_1 | 1 | 383.83 | 383.01 | 0.21% | :white_check_mark: |
| onnx-taau-downsample | 1 | nan | 343.03 | nan% | :x: |
| dlrm-criteoterabyte | 1 | 33.34 | 33.33 | 0.05% | :white_check_mark: |
| dlrm-criteoterabyte_fp16 | 1 | 52.71 | 52.73 | -0.04% | :white_check_mark: |
| agentmodel | 1 | 8,200.96 | 8,178.96 | 0.27% | :white_check_mark: |
| unet_fp16 | 2 | nan | 58.92 | nan% | :x: |
| resnet50v1_fp16 | 1 | nan | 925.30 | nan% | :x: |
| resnet50v1_int8 | 1 | nan | 1,011.55 | nan% | :x: |
| bert_base_cased_fp16 | 64 | nan | 1,169.93 | nan% | :x: |
| bert_large_uncased_fp16 | 32 | nan | 363.31 | nan% | :x: |
| bert_large_fp16 | 1 | nan | 200.50 | nan% | :x: |
| distilgpt2_fp16 | 16 | nan | 2,194.69 | nan% | :x: |
| yolov5s | 1 | 545.17 | 533.06 | 2.27% | :white_check_mark: |
| tinyllama | 1 | nan | 43.45 | nan% | :x: |
| vicuna-fastchat | 1 | 170.69 | 172.29 | -0.92% | :white_check_mark: |
| whisper-tiny-encoder | 1 | 417.98 | 417.95 | 0.01% | :white_check_mark: |
| whisper-tiny-decoder | 1 | nan | 425.65 | nan% | :x: |

This build is not recommended to merge :red_circle:

migraphx-bot commented 2 weeks ago


:white_check_mark: bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
:white_check_mark: bert-mrpc-tf: PASSED: MIGraphX meets tolerance
:white_check_mark: pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
:white_check_mark: pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
:white_check_mark: torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
:white_check_mark: cadene-dpn92_1: PASSED: MIGraphX meets tolerance
:white_check_mark: cadene-resnext101_1: PASSED: MIGraphX meets tolerance
:white_check_mark: dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
:white_check_mark: agentmodel: PASSED: MIGraphX meets tolerance
:white_check_mark: unet: PASSED: MIGraphX meets tolerance
:white_check_mark: resnet50v1: PASSED: MIGraphX meets tolerance
:x: bert_base_cased_fp16: ERROR - check error output
:x: bert_large_uncased_fp16: ERROR - check error output
:white_check_mark: bert_large: PASSED: MIGraphX meets tolerance
:white_check_mark: yolov5s: PASSED: MIGraphX meets tolerance
:x: tinyllama: ERROR - check error output
:white_check_mark: vicuna-fastchat: PASSED: MIGraphX meets tolerance
:white_check_mark: whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
:x: whisper-tiny-decoder: ERROR - check error output
:x: distilgpt2_fp16: ERROR - check error output

lakhinderwalia commented 2 days ago

> Is this still needed? Too many failures to approve

These failures likely have nothing to do with this change, @causten: the fix I added is for an int8 type that isn't exercised anywhere in the codebase. But this PR is still important for debugging our current fusion issues, as I am debugging one right now. It is kept as a back-burner item. Thanks.