ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

Roialign fix and half_pixel mode support #3482

Open bpickrel opened 1 month ago

bpickrel commented 1 month ago

Fix bugs in the implementation of the ROIAlign operation that were found when attempting to run it with the half_pixel coordinate conversion mode, and add more thorough tests. Some bugs are mode-specific and some are not.

The ROIAlign operation was first proposed in the Mask R-CNN paper (https://arxiv.org/abs/1703.06870v3) as a variant of the ROIPool operation that was found to give significantly better accuracy. In the Torch, ONNX Runtime, and MIGraphX implementations, ROIPool and ROIAlign are handled by the same op with different choices for the coordinate conversion mode attribute: output_half_pixel for ROIPool and half_pixel for ROIAlign. Thus there is no true ROIAlign op without fixing the half_pixel mode.

Note, by the way, that these same coordinate conversion modes are also attributes of the Resize op.
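
For context, here is a minimal sketch of an ONNX model that selects the half_pixel mode on a RoiAlign node. The shapes, attribute values, and file name are arbitrary illustrations, not taken from this PR.

```python
# Hypothetical example: build a one-node ONNX model containing RoiAlign with
# coordinate_transformation_mode="half_pixel" (available from opset 16).
# All shapes and attribute values are placeholders chosen for illustration.
import onnx
from onnx import TensorProto, helper

node = helper.make_node(
    "RoiAlign",
    inputs=["X", "rois", "batch_indices"],
    outputs=["Y"],
    coordinate_transformation_mode="half_pixel",  # "output_half_pixel" is the legacy, ROIPool-style behavior
    mode="avg",            # pooling mode; "max" is the other option
    output_height=5,
    output_width=5,
    sampling_ratio=2,
    spatial_scale=1.0,
)

graph = helper.make_graph(
    [node],
    "roialign_half_pixel_example",
    inputs=[
        helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 10, 10]),
        helper.make_tensor_value_info("rois", TensorProto.FLOAT, [2, 4]),
        helper.make_tensor_value_info("batch_indices", TensorProto.INT64, [2]),
    ],
    outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 5, 5])],
)

model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 16)])
onnx.checker.check_model(model)
onnx.save(model, "roialign_half_pixel.onnx")
```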

MIGraphX uses the ONNX Runtime implementation of ROIAlign as its functional specification and should give identical results.

This change is a prerequisite for torch-migraphx PR #143 but does not close it.

bpickrel commented 1 month ago

Don't review this yet! It's very incomplete and I just created the PR to make it easy to visualize the changes so far.

codecov[bot] commented 1 month ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 92.18%. Comparing base (fc26f01) to head (400bd07). Report is 66 commits behind head on develop.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           develop    #3482      +/-   ##
===========================================
+ Coverage    92.02%   92.18%   +0.16%
===========================================
  Files          509      513       +4
  Lines        21014    21573     +559
===========================================
+ Hits         19339    19888     +549
- Misses        1675     1685      +10
```

:umbrella: View full report in Codecov by Sentry.

bpickrel commented 1 month ago

Reviewers: This PR isn't quite ready for review as I'm still working on the GPU implementation. I'm just opening it to activate the Jenkins testing.

bpickrel commented 1 month ago

This PR is finally ready to begin review. Note that I left a "todo" about recalculating the indexing order in the Reference op. If we come up with a way to iterate through the array in the correct order the first time through, great. But getting the order right to this point (just to obtain correct test values) has been very time-consuming, and if a fix doesn't present itself, it may be better to stick with the workaround.

bpickrel commented 2 weeks ago

The licensing-check failure now occurring is for a file not related to this PR:

Error: The licenses for the following 1 file(s) either... do not match the year of commit, have a different copyright format or have not been synced from the latest roialign_fix branch:
['src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp']

bpickrel commented 2 weeks ago

> Looks fine, just a few small things. I haven't been able to fully wrap my head around all the math in the ref and gpu impl, but the index changes look reasonable. Do we have a way to directly test against ORT (without manually extracting gold outputs)? If so, I think it would be worthwhile to add a few more tests comparing with ORT.

I think it would be possible to add a test following the model of the existing tests in test/py/. With luck it wouldn't be very much extra work, half a day or so. @pfultz2 what do you think? The rationale for adding an op test here is that the ROIAlign op is defined in terms of the ONNX Runtime implementation, so it makes sense to have a specialized test with ORT as the reference (see the sketch after the next paragraph).

Note my recent comment: I learned that the ORT implementation of the max pooling option is buggy and can't be used as a test reference until the fix is released. I don't know whether max pooling is widely used with this op or not.
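
A rough sketch of what such a check could look like for the "avg" pooling mode is below. The model file, input shapes, and tolerances are placeholders, and the MIGraphX Python calls (parse_onnx, compile, run) are an assumption about how a test/py-style harness would drive them, not code from this PR.

```python
# Sketch only: compare MIGraphX against ONNX Runtime for RoiAlign in "avg"
# mode ("max" is excluded because of the known ORT bug mentioned above).
import numpy as np
import onnxruntime as ort
import migraphx

model_path = "roialign_half_pixel.onnx"  # placeholder model file

x = np.random.rand(1, 3, 10, 10).astype(np.float32)
rois = np.array([[0.0, 0.0, 9.0, 9.0], [1.0, 1.0, 8.0, 8.0]], dtype=np.float32)
batch_indices = np.zeros(2, dtype=np.int64)
feeds = {"X": x, "rois": rois, "batch_indices": batch_indices}

# Reference output from ONNX Runtime on CPU.
sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
ort_out = sess.run(None, feeds)[0]

# MIGraphX output (assumed Python API usage).
prog = migraphx.parse_onnx(model_path)
prog.compile(migraphx.get_target("gpu"))
mgx_out = np.array(prog.run({k: migraphx.argument(v) for k, v in feeds.items()})[0])

np.testing.assert_allclose(mgx_out, ort_out, rtol=1e-3, atol=1e-3)
```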

bpickrel commented 2 weeks ago

> Looks fine, just a few small things. I haven't been able to fully wrap my head around all the math in the ref and gpu impl, but the index changes look reasonable.

Do you want me to go over it with you? I can explain the intent of nearly everything but the indexing is still very difficult to unravel.

bpickrel commented 5 days ago

Requesting re-review after a recent change: I added a Python test, test_roialign.py, to check MIGraphX output directly against onnxruntime, and found that MIGraphX results were internally consistent but emitted the right values in a transposed shape. Fixing this required changes to internal computations; I updated both the ref. and GPU implementations to emit the corrected shape.

Repeat of an earlier comment: we can't do a similar check vs. onnxruntime for "max" pooling mode because the ORT implementation of max pooling in ROIAlign has a known bug.

migraphx-bot commented 4 days ago
| Test | Batch | Rate new (400bd0) | Rate old (c51bea) | Diff | Compare |
| --- | --- | --- | --- | --- | --- |
| torchvision-resnet50 | 64 | 3,258.94 | 3,257.81 | 0.03% | :white_check_mark: |
| torchvision-resnet50_fp16 | 64 | 6,988.19 | 6,987.81 | 0.01% | :white_check_mark: |
| torchvision-densenet121 | 32 | 2,431.87 | 2,434.57 | -0.11% | :white_check_mark: |
| torchvision-densenet121_fp16 | 32 | 4,099.62 | 4,065.61 | 0.84% | :white_check_mark: |
| torchvision-inceptionv3 | 32 | 1,636.68 | 1,637.17 | -0.03% | :white_check_mark: |
| torchvision-inceptionv3_fp16 | 32 | 2,761.86 | 2,759.26 | 0.09% | :white_check_mark: |
| cadene-inceptionv4 | 16 | 775.66 | 776.31 | -0.08% | :white_check_mark: |
| cadene-resnext64x4 | 16 | 808.05 | 811.75 | -0.46% | :white_check_mark: |
| slim-mobilenet | 64 | 7,525.92 | 7,533.16 | -0.10% | :white_check_mark: |
| slim-nasnetalarge | 64 | 211.28 | 211.39 | -0.05% | :white_check_mark: |
| slim-resnet50v2 | 64 | 3,497.50 | 3,504.83 | -0.21% | :white_check_mark: |
| bert-mrpc-onnx | 8 | 1,147.54 | 1,146.47 | 0.09% | :white_check_mark: |
| bert-mrpc-tf | 1 | 464.53 | 473.89 | -1.98% | :white_check_mark: |
| pytorch-examples-wlang-gru | 1 | 413.15 | 425.31 | -2.86% | :white_check_mark: |
| pytorch-examples-wlang-lstm | 1 | 389.69 | 408.68 | -4.65% | :red_circle: |
| torchvision-resnet50_1 | 1 | 806.71 | 771.75 | 4.53% | :high_brightness: |
| cadene-dpn92_1 | 1 | 399.87 | 399.01 | 0.22% | :white_check_mark: |
| cadene-resnext101_1 | 1 | 382.86 | 383.85 | -0.26% | :white_check_mark: |
| onnx-taau-downsample | 1 | 343.04 | 343.09 | -0.02% | :white_check_mark: |
| dlrm-criteoterabyte | 1 | 33.33 | 33.31 | 0.05% | :white_check_mark: |
| dlrm-criteoterabyte_fp16 | 1 | 52.71 | 52.71 | 0.01% | :white_check_mark: |
| agentmodel | 1 | 7,901.33 | 8,235.67 | -4.06% | :red_circle: |
| unet_fp16 | 2 | 58.79 | 58.90 | -0.19% | :white_check_mark: |
| resnet50v1_fp16 | 1 | 948.60 | 940.89 | 0.82% | :white_check_mark: |
| resnet50v1_int8 | 1 | 1,002.37 | 1,025.93 | -2.30% | :white_check_mark: |
| bert_base_cased_fp16 | 64 | 1,171.54 | 1,170.88 | 0.06% | :white_check_mark: |
| bert_large_uncased_fp16 | 32 | 363.60 | 363.69 | -0.02% | :white_check_mark: |
| bert_large_fp16 | 1 | 200.49 | 200.14 | 0.18% | :white_check_mark: |
| distilgpt2_fp16 | 16 | 2,202.57 | 2,200.77 | 0.08% | :white_check_mark: |
| yolov5s | 1 | 543.48 | 535.15 | 1.56% | :white_check_mark: |
| tinyllama | 1 | 43.46 | 43.41 | 0.10% | :white_check_mark: |
| vicuna-fastchat | 1 | 175.77 | 178.09 | -1.30% | :white_check_mark: |
| whisper-tiny-encoder | 1 | 417.88 | 418.18 | -0.07% | :white_check_mark: |
| whisper-tiny-decoder | 1 | 427.73 | 427.58 | 0.03% | :white_check_mark: |

This build is not recommended to merge :red_circle:

migraphx-bot commented 4 days ago

:white_check_mark: bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
:white_check_mark: bert-mrpc-tf: PASSED: MIGraphX meets tolerance
:white_check_mark: pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
:white_check_mark: pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
:white_check_mark: torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
:white_check_mark: cadene-dpn92_1: PASSED: MIGraphX meets tolerance
:white_check_mark: cadene-resnext101_1: PASSED: MIGraphX meets tolerance
:white_check_mark: dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
:white_check_mark: agentmodel: PASSED: MIGraphX meets tolerance
:white_check_mark: unet: PASSED: MIGraphX meets tolerance
:white_check_mark: resnet50v1: PASSED: MIGraphX meets tolerance
:white_check_mark: bert_base_cased_fp16: PASSED: MIGraphX meets tolerance
:red_circle: bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
:white_check_mark: bert_large: PASSED: MIGraphX meets tolerance
:white_check_mark: yolov5s: PASSED: MIGraphX meets tolerance
:white_check_mark: tinyllama: PASSED: MIGraphX meets tolerance
:white_check_mark: vicuna-fastchat: PASSED: MIGraphX meets tolerance
:white_check_mark: whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
:white_check_mark: whisper-tiny-decoder: PASSED: MIGraphX meets tolerance
:white_check_mark: distilgpt2_fp16: PASSED: MIGraphX meets tolerance

pfultz2 commented 4 days ago

You should capture the onnxruntime results and just create a ref test.
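
As an illustration of that suggestion, one could capture the ONNX Runtime outputs offline with deterministic inputs and paste them into a ref test as fixed gold values. This is only a sketch under that assumption; the model file and inputs below are placeholders, not taken from this PR.

```python
# Sketch: dump ONNX Runtime results for RoiAlign so they can be hard-coded
# as gold values in a ref test. Inputs are deterministic so the captured
# values are reproducible.
import numpy as np
import onnxruntime as ort

x = (np.arange(300, dtype=np.float32) / 300.0).reshape(1, 3, 10, 10)
rois = np.array([[0.0, 0.0, 9.0, 9.0], [1.0, 1.0, 8.0, 8.0]], dtype=np.float32)
batch_indices = np.zeros(2, dtype=np.int64)

sess = ort.InferenceSession("roialign_half_pixel.onnx",
                            providers=["CPUExecutionProvider"])
(gold,) = sess.run(None, {"X": x, "rois": rois, "batch_indices": batch_indices})

# Print the flattened values so they can be embedded as literals in the ref test.
print(", ".join(f"{v:.8f}" for v in gold.flatten()))
```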