migraphx-benchmark / AMDMIGraphX

AMD's graph optimization engine.
https://rocmsoftwareplatform.github.io/AMDMIGraphX/doc/html/
MIT License

ConvTranspose2d inaccuracies #135

Open attila-dusnoki-htec opened 1 year ago

attila-dusnoki-htec commented 1 year ago

Failing tests:

FAIL: test_ConvTranspose2d_cpu (__main__.OnnxBackendPyTorchConvertedModelTest)

```
======================================================================
FAIL: test_ConvTranspose2d_cpu (__main__.OnnxBackendPyTorchConvertedModelTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/onnx/backend/test/runner/__init__.py", line 290, in device_test_func
    return test_func(*args, device=device, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/onnx/backend/test/runner/__init__.py", line 467, in run
    self.assert_similar_outputs(
  File "../test/py/onnx_backend_test.py", line 59, in assert_similar_outputs
    np.testing.assert_allclose(ref_outputs[i],
  File "/usr/local/lib/python3.8/dist-packages/numpy/testing/_private/utils.py", line 1530, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/usr/local/lib/python3.8/dist-packages/numpy/testing/_private/utils.py", line 844, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=0.001, atol=1e-05

Program =
module: "main"
0 = @param:0 -> float_type, {1, 3, 7, 6}, {126, 42, 6, 1}, target_id=0
@1 = @literal{0.156559, -0.0778541, 0.00345176, -0.162995} -> float_type, {4}, {1}, target_id=0
@2 = @literal{ ... } -> float_type, {3, 4, 3, 3}, {36, 9, 3, 1}, target_id=0
@3 = convolution_backwards[padding={1, 1},stride={3, 2},dilation={1, 1},padding_mode=0,group=1](0,@2) -> float_type, {1, 4, 19, 11}, {836, 209, 11, 1}, target_id=0
@4 = pad[mode=0,pads={0, 0, 0, 0, 0, 0, 1, 1},value=0](@3) -> float_type, {1, 4, 20, 12}, {960, 240, 12, 1}, target_id=0
@5 = broadcast[axis=1,out_lens={1, 4, 20, 12}](@1) -> float_type, {1, 4, 20, 12}, {0, 1, 0, 0}, target_id=0
@6 = add(@4,@5) -> float_type, {1, 4, 20, 12}, {960, 240, 12, 1}, target_id=0
@7 = @return(@6), target_id=0

Compiled program =
module: "main"
@0 = check_context::migraphx::gpu::context -> float_type, {}, {}, target_id=0
@1 = hip::hip_allocate_memory[shape=int8_type, {7680}, {1},id=main:scratch] -> int8_type, {7680}, {1}, target_id=0
@2 = hip::hip_copy_literal[id=main:@literal:0] -> float_type, {4}, {1}, target_id=0
@3 = hip::hip_copy_literal[id=main:@literal:1] -> float_type, {3, 4, 3, 3}, {36, 9, 3, 1}, target_id=0
0 = @param:0 -> float_type, {1, 3, 7, 6}, {126, 42, 6, 1}, target_id=0
@5 = load[offset=0,end=504](@1) -> float_type, {1, 3, 7, 6}, {126, 42, 6, 1}, target_id=0
@6 = hip::copy_to_gpu(0,@5) -> float_type, {1, 3, 7, 6}, {126, 42, 6, 1}, target_id=0
@7 = load[offset=3840,end=7184](@1) -> float_type, {1, 4, 19, 11}, {836, 209, 11, 1}, target_id=0
@8 = load[offset=0,end=0](@1) -> int8_type, {0}, {1}, target_id=0
@9 = gpu::convolution_backwards[op={padding={1, 1},stride={3, 2},dilation={1, 1},padding_mode=0,group=1},solution_object={binary_object: 423},algo=0,int8_x4_format=0,solution_id=0](@6,@3,@8,@7) -> float_type, {1, 4, 19, 11}, {836, 209, 11, 1}, target_id=0
@10 = load[offset=0,end=3840](@1) -> float_type, {1, 4, 20, 12}, {960, 240, 12, 1}, target_id=0
@11 = gpu::code_object[code_object=9528,symbol_name=pad_kernel,global=1024,local=1024,](@9,@10) -> float_type, {1, 4, 20, 12}, {960, 240, 12, 1}, target_id=0
@12 = load[offset=3840,end=7680](@1) -> float_type, {1, 4, 20, 12}, {960, 240, 12, 1}, target_id=0
@13 = broadcast[axis=1,out_lens={1, 4, 20, 12}](@2) -> float_type, {1, 4, 20, 12}, {0, 1, 0, 0}, target_id=0
@14 = gpu::code_object[code_object=9528,symbol_name=add_kernel,global=1024,local=1024,](@11,@13,@12) -> float_type, {1, 4, 20, 12}, {960, 240, 12, 1}, target_id=0
@15 = hip::copy_from_gpu(@14) -> float_type, {1, 4, 20, 12}, {960, 240, 12, 1}, target_id=0
@16 = hip::sync_stream(@15) -> float_type, {1, 4, 20, 12}, {960, 240, 12, 1}, target_id=0
@17 = @return(@16), target_id=0

Mismatched elements: 124 / 960 (12.9%)
Max absolute difference: 0.43139315
Max relative difference: 106.346725
 x: array([[[[-3.870082e-02,  4.058291e-01,  9.855538e-02, -3.768350e-01,
          -1.542787e-02,  6.494146e-01,  4.276419e-01, -6.573269e-01,
           3.279429e-01, -1.139378e-01,  1.777776e-01,  2.557690e-02],...
 y: array([[[[-3.870082e-02,  4.058291e-01,  9.855538e-02, -3.768350e-01,
          -1.542789e-02,  6.494146e-01,  4.276419e-01, -6.573269e-01,
           3.279429e-01, -1.139379e-01,  1.777776e-01,  1.565587e-01],...
```
FAIL: test_ConvTranspose2d_no_bias_cpu (__main__.OnnxBackendPyTorchConvertedModelTest)

```
======================================================================
FAIL: test_ConvTranspose2d_no_bias_cpu (__main__.OnnxBackendPyTorchConvertedModelTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/onnx/backend/test/runner/__init__.py", line 290, in device_test_func
    return test_func(*args, device=device, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/onnx/backend/test/runner/__init__.py", line 467, in run
    self.assert_similar_outputs(
  File "../test/py/onnx_backend_test.py", line 59, in assert_similar_outputs
    np.testing.assert_allclose(ref_outputs[i],
  File "/usr/local/lib/python3.8/dist-packages/numpy/testing/_private/utils.py", line 1530, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/usr/local/lib/python3.8/dist-packages/numpy/testing/_private/utils.py", line 844, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=0.001, atol=1e-05

Program =
module: "main"
0 = @param:0 -> float_type, {1, 3, 6, 7}, {126, 42, 7, 1}, target_id=0
@1 = @literal{ ... } -> float_type, {3, 4, 3, 3}, {36, 9, 3, 1}, target_id=0
@2 = convolution_backwards[padding={1, 1},stride={2, 3},dilation={1, 1},padding_mode=0,group=1](0,@1) -> float_type, {1, 4, 11, 19}, {836, 209, 19, 1}, target_id=0
@3 = pad[mode=0,pads={0, 0, 0, 0, 0, 0, 1, 1},value=0](@2) -> float_type, {1, 4, 12, 20}, {960, 240, 20, 1}, target_id=0
@4 = @return(@3), target_id=0

Compiled program =
module: "main"
@0 = check_context::migraphx::gpu::context -> float_type, {}, {}, target_id=0
@1 = hip::hip_allocate_memory[shape=int8_type, {7184}, {1},id=main:scratch] -> int8_type, {7184}, {1}, target_id=0
@2 = load[offset=3344,end=3848](@1) -> float_type, {1, 3, 6, 7}, {126, 42, 7, 1}, target_id=0
0 = @param:0 -> float_type, {1, 3, 6, 7}, {126, 42, 7, 1}, target_id=0
@4 = hip::copy_to_gpu(0,@2) -> float_type, {1, 3, 6, 7}, {126, 42, 7, 1}, target_id=0
@5 = hip::hip_copy_literal[id=main:@literal:0] -> float_type, {3, 4, 3, 3}, {36, 9, 3, 1}, target_id=0
@6 = load[offset=0,end=0](@1) -> int8_type, {0}, {1}, target_id=0
@7 = load[offset=0,end=3344](@1) -> float_type, {1, 4, 11, 19}, {836, 209, 19, 1}, target_id=0
@8 = gpu::convolution_backwards[op={padding={1, 1},stride={2, 3},dilation={1, 1},padding_mode=0,group=1},solution_object={binary_object: 423},algo=0,int8_x4_format=0,solution_id=0](@4,@5,@6,@7) -> float_type, {1, 4, 11, 19}, {836, 209, 19, 1}, target_id=0
@9 = load[offset=3344,end=7184](@1) -> float_type, {1, 4, 12, 20}, {960, 240, 20, 1}, target_id=0
@10 = gpu::code_object[code_object=9528,symbol_name=pad_kernel,global=1024,local=1024,](@8,@9) -> float_type, {1, 4, 12, 20}, {960, 240, 20, 1}, target_id=0
@11 = hip::copy_from_gpu(@10) -> float_type, {1, 4, 12, 20}, {960, 240, 20, 1}, target_id=0
@12 = hip::sync_stream(@11) -> float_type, {1, 4, 12, 20}, {960, 240, 20, 1}, target_id=0
@13 = @return(@12), target_id=0

Mismatched elements: 124 / 960 (12.9%)
Max absolute difference: 0.81626
Max relative difference: 1.4067205e-05
 x: array([[[[ 4.195647e-01, -2.791808e-01, -4.167238e-01, -2.238854e-01,
           1.521523e-01, -2.762069e-01, -2.465742e-01,  1.158977e-01,
          -5.013817e-01, -2.075676e-01,  1.123053e-02,  3.795193e-01,...
 y: array([[[[ 4.195647e-01, -2.791808e-01, -4.167239e-01, -2.238854e-01,
           1.521523e-01, -2.762069e-01, -2.465741e-01,  1.158977e-01,
          -5.013817e-01, -2.075676e-01,  1.123052e-02,  3.795193e-01,...
```
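As a sanity check, the shapes in the dumps line up with the standard ConvTranspose output-size formula. The helper below is a sketch of mine, not part of the test suite:

```python
def conv_transpose_len(in_len, kernel, stride, pad, dilation=1, output_padding=0):
    # ONNX/PyTorch ConvTranspose output length along one spatial axis
    return (in_len - 1) * stride - 2 * pad + dilation * (kernel - 1) + output_padding + 1

# First failing test: input {1, 3, 7, 6}, 3x3 kernel, stride {3, 2},
# padding {1, 1}, output_padding 1 on each axis:
print(conv_transpose_len(7, 3, 3, 1, output_padding=1))  # 20
print(conv_transpose_len(6, 3, 2, 1, output_padding=1))  # 12
# Without output_padding the formula gives the {19, 11} intermediate
# shape that convolution_backwards produces before the pad op:
print(conv_transpose_len(7, 3, 3, 1))  # 19
print(conv_transpose_len(6, 3, 2, 1))  # 11
```

So the final {1, 4, 20, 12} output shape is correct; the question is how the extra row/column is filled.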
attila-dusnoki-htec commented 11 months ago


output_padding is currently lowered as a separate pad op that runs after the convolution. In the failures above this means the extra row/column is filled with zeros instead of computed values, with only the bias (0.156559, i.e. 1.565587e-01 in the diff below) added at the end.

```
 x: array([[[[-3.870082e-02,  4.058291e-01,  9.855538e-02, -3.768350e-01,
          -1.542787e-02,  6.494146e-01,  4.276419e-01, -6.573269e-01,
           3.279429e-01, -1.139378e-01,  1.777776e-01,  !!2.557690e-02!!],...
 y: array([[[[-3.870082e-02,  4.058291e-01,  9.855538e-02, -3.768350e-01,
          -1.542789e-02,  6.494146e-01,  4.276419e-01, -6.573269e-01,
           3.279429e-01, -1.139379e-01,  1.777776e-01,  !!1.565587e-01!!],...
```

Note: no other test uses the output_padding attribute.
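To illustrate why a trailing zero-pad is not equivalent: in the reference semantics (as PyTorch documents it), output_padding only enlarges the output window by cropping fewer elements from the bottom/right edge of the full scatter-add result, so the extra row/column holds computed values, not zeros. A minimal single-channel NumPy sketch, with function name and toy data of my own choosing:

```python
import numpy as np

def conv_transpose2d(x, w, stride, padding, output_padding):
    """Toy single-channel transposed convolution.

    output_padding does not append zeros; it crops fewer elements from
    the bottom/right edge, so those positions keep computed values.
    """
    (H, W), (kH, kW) = x.shape, w.shape
    sH, sW = stride
    pH, pW = padding
    opH, opW = output_padding
    fullH, fullW = (H - 1) * sH + kH, (W - 1) * sW + kW
    full = np.zeros((fullH, fullW))
    for i in range(H):          # scatter-add one scaled kernel copy per input pixel
        for j in range(W):
            full[i * sH:i * sH + kH, j * sW:j * sW + kW] += x[i, j] * w
    # crop `padding` from top/left, `padding - output_padding` from bottom/right
    return full[pH:fullH - pH + opH, pW:fullW - pW + opW]

x = np.arange(42, dtype=float).reshape(7, 6)   # 7x6 input, as in the failing test
w = np.ones((3, 3))
out = conv_transpose2d(x, w, stride=(3, 2), padding=(1, 1), output_padding=(1, 1))
print(out.shape)          # (20, 12), the expected output shape
print(out[:, -1].any())   # True: the extra column is not all zeros
```

A backend that instead pads the {19, 11} result with a zero row/column would leave that last column at zero (plus bias), which is exactly the 12.9% mismatch pattern in the diffs above.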

attila-dusnoki-htec commented 11 months ago

Possibly related issue: https://github.com/ROCmSoftwarePlatform/AMDMIGraphX/issues/1868