ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

Enable GEMM/dot for FP8 using hipblasLT #3577

Closed: CharlieL7 closed this 3 days ago

CharlieL7 commented 2 weeks ago

Looks to be working correctly on Navi4X.

codecov[bot] commented 2 weeks ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 92.17%. Comparing base (1cfd6c2) to head (519a63b). Report is 5 commits behind head on develop.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           develop    #3577     +/-   ##
===========================================
- Coverage    92.17%   92.17%   -0.01%
===========================================
  Files          513      513
  Lines        21560    21558       -2
===========================================
- Hits         19873    19871       -2
  Misses        1687     1687
```


CharlieL7 commented 2 weeks ago

> Approved. Is there any test case that verifies this set of ops? Thanks.

Yes, there are the test_verify gemm tests, which will now use hipblaslt if the flag is enabled.

CharlieL7 commented 1 week ago

Created an issue tracking the preprocessor conditional I added: https://github.com/ROCm/AMDMIGraphX/issues/3592

CharlieL7 commented 1 week ago

I tried using CMake's check_type_size or check_symbol_exists to detect whether the types are present in the hipblaslt version. The HIP_R_F8... values are not really types, so check_type_size doesn't work, and I couldn't get check_symbol_exists to work either. The last option I found was check_source_compiles, but that is far too cumbersome to use.
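For reference, a compile probe of the kind check_source_compiles would need looks roughly like the sketch below. This is an illustration only, not code from this PR: the enum value HIP_R_8F_E4M3_FNUZ, the header path, and the result variable name are assumptions based on the HIP header linked in the next comment.

```cmake
# Sketch of a check_source_compiles probe for the HIP FP8 enum values.
# HIP_R_8F_E4M3_FNUZ, the include path, and the variable name are
# assumptions for illustration, not the PR's actual build logic.
include(CheckSourceCompiles)
set(CMAKE_REQUIRED_INCLUDES "${HIP_INCLUDE_DIRS}")
check_source_compiles(CXX "
#include <hip/library_types.h>
int main() {
    hipDataType t = HIP_R_8F_E4M3_FNUZ;  // fails to compile on older HIP
    (void)t;
    return 0;
}
" MIGRAPHX_HAS_HIP_FP8_ENUMS)
```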

pfultz2 commented 1 week ago

The types are defined as an enum in HIP, not in hipblaslt:

https://github.com/ROCm/HIP/blob/3d60bd3a6415c280bd1fe63767ae8e10eea4d2d1/include/hip/library_types.h#L61

So we shouldn't check the hipblaslt version. It looks like hipblaslt already checks the HIP version and then defines `ROCM_USE_FLOAT8`:

https://github.com/ROCm/hipBLASLt/blob/b2adca84509dd31e31b8f42044389128d199b62e/library/include/hipblaslt.h#L67

That macro appears to be there to enable the non-fnuz types:

https://github.com/ROCm/hipBLASLt/blob/b2adca84509dd31e31b8f42044389128d199b62e/library/src/amd_detail/rocblaslt/src/utility.cpp#L79

So we can probably use #ifdef ROCM_USE_FLOAT8 for this.
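A guard along these lines is roughly what that suggestion would amountted to in code. A minimal sketch only: the function name is made up for illustration, and it assumes hipblaslt's installed header is `<hipblaslt/hipblaslt.h>`.

```cpp
// Minimal sketch: compile-time gate on hipblaslt's FP8 support.
// ROCM_USE_FLOAT8 is defined by hipblaslt.h when the HIP headers
// provide the FP8 enums; the function name here is illustrative,
// not the PR's actual code.
#include <hipblaslt/hipblaslt.h>

constexpr bool hipblaslt_fp8_available()
{
#ifdef ROCM_USE_FLOAT8
    return true;
#else
    return false;
#endif
}
```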

CharlieL7 commented 1 week ago

Works well with the `ROCM_USE_FLOAT8` macro.

migraphx-bot commented 3 days ago
| Test | Batch | Rate new (519a63) | Rate old (c51bea) | Diff | Compare |
|------|-------|-------------------|-------------------|------|---------|
| torchvision-resnet50 | 64 | 3,260.35 | 3,257.81 | 0.08% | :white_check_mark: |
| torchvision-resnet50_fp16 | 64 | 6,992.56 | 6,987.81 | 0.07% | :white_check_mark: |
| torchvision-densenet121 | 32 | 2,435.37 | 2,434.57 | 0.03% | :white_check_mark: |
| torchvision-densenet121_fp16 | 32 | 4,052.07 | 4,065.61 | -0.33% | :white_check_mark: |
| torchvision-inceptionv3 | 32 | 1,628.86 | 1,637.17 | -0.51% | :white_check_mark: |
| torchvision-inceptionv3_fp16 | 32 | 2,745.48 | 2,759.26 | -0.50% | :white_check_mark: |
| cadene-inceptionv4 | 16 | 766.04 | 776.31 | -1.32% | :white_check_mark: |
| cadene-resnext64x4 | 16 | 810.66 | 811.75 | -0.13% | :white_check_mark: |
| slim-mobilenet | 64 | 7,468.29 | 7,533.16 | -0.86% | :white_check_mark: |
| slim-nasnetalarge | 64 | 208.48 | 211.39 | -1.38% | :white_check_mark: |
| slim-resnet50v2 | 64 | 3,440.40 | 3,504.83 | -1.84% | :white_check_mark: |
| bert-mrpc-onnx | 8 | 1,151.09 | 1,146.47 | 0.40% | :white_check_mark: |
| bert-mrpc-tf | 1 | 465.75 | 473.89 | -1.72% | :white_check_mark: |
| pytorch-examples-wlang-gru | 1 | 419.90 | 425.31 | -1.27% | :white_check_mark: |
| pytorch-examples-wlang-lstm | 1 | 379.35 | 408.68 | -7.18% | :red_circle: |
| torchvision-resnet50_1 | 1 | 820.74 | 771.75 | 6.35% | :high_brightness: |
| cadene-dpn92_1 | 1 | 400.14 | 399.01 | 0.28% | :white_check_mark: |
| cadene-resnext101_1 | 1 | 382.79 | 383.85 | -0.28% | :white_check_mark: |
| onnx-taau-downsample | 1 | 346.22 | 343.09 | 0.91% | :white_check_mark: |
| dlrm-criteoterabyte | 1 | 33.35 | 33.31 | 0.12% | :white_check_mark: |
| dlrm-criteoterabyte_fp16 | 1 | 52.72 | 52.71 | 0.02% | :white_check_mark: |
| agentmodel | 1 | 8,340.77 | 8,235.67 | 1.28% | :white_check_mark: |
| unet_fp16 | 2 | 58.82 | 58.90 | -0.14% | :white_check_mark: |
| resnet50v1_fp16 | 1 | 917.54 | 940.89 | -2.48% | :white_check_mark: |
| resnet50v1_int8 | 1 | 1,035.41 | 1,025.93 | 0.92% | :white_check_mark: |
| bert_base_cased_fp16 | 64 | 1,170.28 | 1,170.88 | -0.05% | :white_check_mark: |
| bert_large_uncased_fp16 | 32 | 363.71 | 363.69 | 0.01% | :white_check_mark: |
| bert_large_fp16 | 1 | 198.93 | 200.14 | -0.60% | :white_check_mark: |
| distilgpt2_fp16 | 16 | 2,199.74 | 2,200.77 | -0.05% | :white_check_mark: |
| yolov5s | 1 | 534.32 | 535.15 | -0.16% | :white_check_mark: |
| tinyllama | 1 | 43.67 | 43.41 | 0.59% | :white_check_mark: |
| vicuna-fastchat | 1 | 171.49 | 178.09 | -3.71% | :red_circle: |
| whisper-tiny-encoder | 1 | 418.42 | 418.18 | 0.06% | :white_check_mark: |
| whisper-tiny-decoder | 1 | 430.87 | 427.58 | 0.77% | :white_check_mark: |

This build is not recommended to merge :red_circle:

migraphx-bot commented 3 days ago


:white_check_mark: bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
:white_check_mark: bert-mrpc-tf: PASSED: MIGraphX meets tolerance
:white_check_mark: pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
:white_check_mark: pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
:white_check_mark: torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
:white_check_mark: cadene-dpn92_1: PASSED: MIGraphX meets tolerance
:white_check_mark: cadene-resnext101_1: PASSED: MIGraphX meets tolerance
:white_check_mark: dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
:white_check_mark: agentmodel: PASSED: MIGraphX meets tolerance
:white_check_mark: unet: PASSED: MIGraphX meets tolerance
:white_check_mark: resnet50v1: PASSED: MIGraphX meets tolerance
:white_check_mark: bert_base_cased_fp16: PASSED: MIGraphX meets tolerance
:red_circle: bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
:white_check_mark: bert_large: PASSED: MIGraphX meets tolerance
:white_check_mark: yolov5s: PASSED: MIGraphX meets tolerance
:white_check_mark: tinyllama: PASSED: MIGraphX meets tolerance
:white_check_mark: vicuna-fastchat: PASSED: MIGraphX meets tolerance
:white_check_mark: whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
:white_check_mark: whisper-tiny-decoder: PASSED: MIGraphX meets tolerance
:white_check_mark: distilgpt2_fp16: PASSED: MIGraphX meets tolerance