microsoft / onnxscript

ONNX Script enables developers to naturally author ONNX functions and models using a subset of Python.
https://onnxscript.ai/
MIT License
278 stars 53 forks source link

[torchlib] trunc_div float16 off-by-one mismatches #990

Open justinchuby opened 1 year ago

justinchuby commented 1 year ago

Summary

The output of ONNX Runtime does not match that of PyTorch when executing test ops_test.TestOutputConsistencyFullGraphCPU.test_output_match_opinfo__div_mode_trunc_rounding_cpu_float16, sample 5 in ONNX Script TorchLib.

To recreate this report, use

CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_opinfo__div_mode_trunc_rounding_cpu_float16

Inputs

Shapes: ['Tensor<torch.Size([5, 10, 5]), dtype=torch.float16>', 'Tensor<torch.Size([5, 10, 5]), dtype=torch.float16>']

inputs = (tensor([[[-6.6953,  6.7070, -7.7422, -5.9688, -3.5234],
         [-0.0703,  8.8828,  5.0195, -1.2656, -8.9922],
         [-1.1250,  2.6641, -6.7773,  5.7227, -5.0273],
         [ 0.3164,  1.3975,  8.9453, -0.0879, -8.9297],
         [ 2.0391,  0.8613,  5.7812, -1.8369, -7.1797],
         [ 1.1426, -5.4570, -7.6719, -8.5312,  1.0459],
         [ 5.9062,  8.0781, -3.4883, -3.8496,  5.4844],
         [ 8.5000,  3.6836,  3.6992, -8.0938,  7.4805],
         [-6.7773,  1.9863,  0.9756, -4.5273,  4.1484],
         [-8.1406, -7.6641,  7.5586, -0.9756,  7.8672]],

        [[-1.2920, -6.0391,  6.1953, -1.7842,  1.4238],
         [ 6.3203, -0.5977,  5.9766,  8.7422, -7.5938],
         [ 5.0898, -1.9688, -5.5117,  2.4258, -1.8369],
         [ 8.0391,  4.4219,  7.8672,  3.4375,  1.1777],
         [ 1.7051, -5.4844, -3.3828,  0.2812, -2.9609],
         [-4.9648, -0.6152,  7.7500, -4.8789,  3.2871],
         [ 0.1846, -2.4180,  5.8789, -8.6719, -6.9883],
         [ 3.1992, -4.0625, -5.1602,  6.8125,  2.4785],
         [-8.7734, -7.0234,  2.4258, -1.9951,  5.3438],
         [-6.5117, -1.2832, -0.5713,  4.2188, -5.0273]],

        [[-0.0439,  7.2773, -1.8369, -3.2168,  4.5273],
         [-4.2031,  1.6787,  0.4219,  4.4922, -8.8594],
         [ 1.5029,  1.9248, -7.6211, -4.8164,  0.3955],
         [ 3.2695, -4.4570, -8.1406,  7.8320,  6.5391],
         [ 0.3164, -6.5469,  3.0586, -4.6406, -6.3555],
         [-7.4180, -3.6738,  0.8613, -2.8555, -0.2812],
         [ 1.0107, -7.3281, -6.7852, -0.3867, -0.8525],
         [ 8.6328,  1.4062,  2.2422,  2.8301, -7.8828],
         [-8.8516,  6.9609,  8.7969,  6.4531, -4.0352],
         [-4.0000,  2.8301,  7.6562, -2.9805,  6.9531]],

        [[-3.5078,  2.5234,  8.2812, -1.5029,  5.8203],
         [-0.5537,  5.6094, -7.6484, -4.7031,  3.3828],
         [ 8.6562, -2.1094,  0.9053,  8.1562,  3.4453],
         [ 7.1172, -6.8477,  1.5381,  0.3340,  4.1836],
         [ 5.9844, -2.3379, -4.2812,  7.6094,  7.1797],
         [-0.9141, -5.4219, -4.3945, -1.7314,  4.4375],
         [-1.0723,  7.2422, -8.6953, -2.9883, -8.8359],
         [-6.1875,  6.8125, -4.5078,  8.8359, -5.4922],
         [-3.2617, -0.4658,  7.7500, -4.8672,  1.2480],
         [ 0.7998,  3.1016, -1.6787,  7.0234, -2.6191]],

        [[ 8.4219, -1.5029, -8.5625,  5.8359,  6.1250],
         [-3.9551,  1.6611,  1.0898, -7.6914, -6.2500],
         [-4.0000,  2.9961, -6.3281,  6.0742, -3.0156],
         [-1.6436,  6.8750,  0.4658, -0.1934, -5.1055],
         [-3.6211, -0.2812, -6.6875,  0.4570,  5.0195],
         [-0.1846,  5.4414, -8.9688,  0.8965,  1.7139],
         [-2.2070, -1.3008, -0.2900, -1.2393,  3.0664],
         [-8.5938, -1.3887, -1.7754, -0.7822,  7.1992],
         [-3.2422, -6.6445, -5.7578, -3.4180,  0.9229],
         [ 5.8789,  5.7812, -3.2344, -7.9531, -1.7051]]], dtype=torch.float16), tensor([[[ 6.9062,  5.3008,  0.5010,  8.3516,  3.0312],
         [-2.4961,  4.1562, -7.9805,  7.2852, -2.0566],
         [ 1.1074,  3.4453,  7.9727, -0.1055,  5.6250],
         [ 0.5010, -3.4375, -4.5000,  6.5938, -1.9600],
         [ 0.8350,  5.5977,  1.0283, -4.0859,  5.2734],
         [-4.7188,  2.0742,  7.7695, -1.1250,  2.3828],
         [ 1.6523,  8.7734, -2.8203,  3.9199, -3.9023],
         [-2.0664, -1.4414, -6.8359,  8.4531,  0.6592],
         [-7.6562,  5.7578,  3.8145, -0.2461,  0.7471],
         [-2.8828,  2.4961,  8.5547,  1.9248,  0.5977]],

        [[ 0.4043, -0.1406, -8.1797,  6.5820, -0.5889],
         [-0.1143, -5.0273,  1.7842, -0.9932, -7.0938],
         [-4.8438, -5.2109, -6.0117, -4.1641,  5.3711],
         [ 2.5312,  2.2227,  6.4414, -5.8711, -8.8047],
         [ 6.6875, -4.4219,  6.6523,  7.3477,  0.8701],
         [-1.9600,  1.5732,  1.9512,  7.9531, -2.3125],
         [ 5.7812, -5.1250,  5.3359, -2.2500, -0.0791],
         [ 2.8477,  8.8281, -4.4727,  0.1318, -2.0469],
         [-6.0547,  6.6641, -6.7422, -4.6406,  8.6719],
         [-1.1338, -7.6016,  4.1641, -1.5205,  5.8281]],

        [[ 7.9727, -2.9453,  5.9219, -5.6680,  4.7812],
         [ 8.1562, -0.5977, -2.9355,  3.7441,  5.3789],
         [ 2.3477,  7.3281,  5.0625,  3.6562,  1.4678],
         [ 7.3906,  0.3691,  4.8789, -6.5039, -5.5195],
         [-8.9062,  5.3281,  2.6445,  7.8828,  2.2148],
         [-4.6250, -2.3828, -8.7031, -2.6191, -1.5469],
         [-8.2188, -2.1875, -3.6914,  8.0938, -0.4834],
         [ 5.2812, -0.0264, -5.3359, -8.0000, -3.5859],
         [ 7.9297,  1.7139, -1.7490,  4.3945,  5.4922],
         [ 4.3242,  0.6504, -5.3789,  3.2969,  3.6836]],

        [[-2.6016,  2.4258, -0.0352, -2.3203, -3.0664],
         [ 3.4531,  7.6289, -4.9922, -0.9229, -7.3203],
         [-2.5137,  6.0469,  0.8613,  8.8516,  4.8086],
         [-0.4131, -0.1406,  8.7734, -8.1641,  5.8281],
         [-3.1816,  1.6611,  6.4609, -8.6875,  6.8477],
         [ 3.2344,  6.1016,  4.6055, -6.5547, -7.1016],
         [-1.3887, -1.3359,  7.0039,  2.3906,  7.7344],
         [-8.4609, -3.7695,  7.7266, -2.6016,  2.9961],
         [-5.5195, -2.8652,  0.1582, -7.4531,  5.3711],
         [-5.5273, -5.0977,  2.0391,  8.0312, -8.3438]],

        [[-6.6797, -5.0078,  6.9883, -2.6992, -2.5488],
         [ 3.3672,  4.9492, -1.5293,  4.4375,  3.2168],
         [-7.7422, -1.2305,  0.5977,  0.8613, -5.5273],
         [ 0.2109, -0.6768, -1.1777,  4.1133,  4.2461],
         [ 4.9570,  1.3184,  1.2568, -4.0078,  0.3340],
         [-8.0469,  6.0820,  0.3604,  1.6260, -0.7910],
         [-5.5117,  1.9512, -3.4375,  1.1602,  1.4238],
         [ 1.8809, -2.5664, -6.9453,  7.3984, -3.2266],
         [-2.4082, -3.8398,  0.0703,  0.6416, -0.6240],
         [-4.6836, -0.9844,  2.2148,  2.0117,  4.5547]]], dtype=torch.float16))
kwargs = {'rounding_mode': 'trunc'}

Expected output

expected = tensor([[[  -0.,    1.,  -15.,   -0.,   -1.],
         [   0.,    2.,   -0.,   -0.,    4.],
         [  -1.,    0.,   -0.,  -54.,   -0.],
         [   0.,   -0.,   -1.,   -0.,    4.],
         [   2.,    0.,    5.,    0.,   -1.],
         [  -0.,   -2.,   -0.,    7.,    0.],
         [   3.,    0.,    1.,   -0.,   -1.],
         [  -4.,   -2.,   -0.,   -0.,   11.],
         [   0.,    0.,    0.,   18.,    5.],
         [   2.,   -3.,    0.,   -0.,   13.]],

        [[  -3.,   42.,   -0.,   -0.,   -2.],
         [ -55.,    0.,    3.,   -8.,    1.],
         [  -1.,    0.,    0.,   -0.,   -0.],
         [   3.,    1.,    1.,   -0.,   -0.],
         [   0.,    1.,   -0.,    0.,   -3.],
         [   2.,   -0.,    3.,   -0.,   -1.],
         [   0.,    0.,    1.,    3.,   88.],
         [   1.,   -0.,    1.,   51.,   -1.],
         [   1.,   -1.,   -0.,    0.,    0.],
         [   5.,    0.,   -0.,   -2.,   -0.]],

        [[  -0.,   -2.,   -0.,    0.,    0.],
         [  -0.,   -2.,   -0.,    1.,   -1.],
         [   0.,    0.,   -1.,   -1.,    0.],
         [   0.,  -12.,   -1.,   -1.,   -1.],
         [  -0.,   -1.,    1.,   -0.,   -2.],
         [   1.,    1.,   -0.,    1.,    0.],
         [  -0.,    3.,    1.,   -0.,    1.],
         [   1.,  -53.,   -0.,   -0.,    2.],
         [  -1.,    4.,   -5.,    1.,   -0.],
         [  -0.,    4.,   -1.,   -0.,    1.]],

        [[   1.,    1., -235.,    0.,   -1.],
         [  -0.,    0.,    1.,    5.,   -0.],
         [  -3.,   -0.,    1.,    0.,    0.],
         [ -17.,   48.,    0.,   -0.,    0.],
         [  -1.,   -1.,   -0.,   -0.,    1.],
         [  -0.,   -0.,   -0.,    0.,   -0.],
         [   0.,   -5.,   -1.,   -1.,   -1.],
         [   0.,   -1.,   -0.,   -3.,   -1.],
         [   0.,    0.,   49.,    0.,    0.],
         [  -0.,   -0.,   -0.,    0.,    0.]],

        [[  -1.,    0.,   -1.,   -2.,   -2.],
         [  -1.,    0.,   -0.,   -1.,   -1.],
         [   0.,   -2.,  -10.,    7.,    0.],
         [  -7.,  -10.,   -0.,   -0.,   -1.],
         [  -0.,   -0.,   -5.,   -0.,   15.],
         [   0.,    0.,  -24.,    0.,   -2.],
         [   0.,   -0.,    0.,   -1.,    2.],
         [  -4.,    0.,    0.,   -0.,   -2.],
         [   1.,    1.,  -81.,   -5.,   -1.],
         [  -1.,   -5.,   -1.,   -3.,   -0.]]], dtype=torch.float16)

Shape: torch.Size([5, 10, 5])

Actual output

actual = tensor([[[   0.,    1.,  -15.,    0.,   -1.],
         [   0.,    2.,    0.,    0.,    4.],
         [  -1.,    0.,    0.,  -54.,    0.],
         [   0.,    0.,   -1.,    0.,    4.],
         [   2.,    0.,    5.,    0.,   -1.],
         [   0.,   -2.,    0.,    7.,    0.],
         [   3.,    0.,    1.,    0.,   -1.],
         [  -4.,   -2.,    0.,    0.,   11.],
         [   0.,    0.,    0.,   18.,    5.],
         [   2.,   -3.,    0.,    0.,   13.]],

        [[  -3.,   42.,    0.,    0.,   -2.],
         [ -55.,    0.,    3.,   -8.,    1.],
         [  -1.,    0.,    0.,    0.,    0.],
         [   3.,    1.,    1.,    0.,    0.],
         [   0.,    1.,    0.,    0.,   -3.],
         [   2.,    0.,    3.,    0.,   -1.],
         [   0.,    0.,    1.,    3.,   88.],
         [   1.,    0.,    1.,   51.,   -1.],
         [   1.,   -1.,    0.,    0.,    0.],
         [   5.,    0.,    0.,   -2.,    0.]],

        [[   0.,   -2.,    0.,    0.,    0.],
         [   0.,   -2.,    0.,    1.,   -1.],
         [   0.,    0.,   -1.,   -1.,    0.],
         [   0.,  -12.,   -1.,   -1.,   -1.],
         [   0.,   -1.,    1.,    0.,   -2.],
         [   1.,    1.,    0.,    1.,    0.],
         [   0.,    3.,    1.,    0.,    1.],
         [   1.,  -53.,    0.,    0.,    2.],
         [  -1.,    4.,   -5.,    1.,    0.],
         [   0.,    4.,   -1.,    0.,    1.]],

        [[   1.,    1., -235.,    0.,   -1.],
         [   0.,    0.,    1.,    5.,    0.],
         [  -3.,    0.,    1.,    0.,    0.],
         [ -17.,   48.,    0.,    0.,    0.],
         [  -1.,   -1.,    0.,    0.,    1.],
         [   0.,    0.,    0.,    0.,    0.],
         [   0.,   -5.,   -1.,   -1.,   -1.],
         [   0.,   -1.,    0.,   -3.,   -1.],
         [   0.,    0.,   48.,    0.,    0.],
         [   0.,    0.,    0.,    0.,    0.]],

        [[  -1.,    0.,   -1.,   -2.,   -2.],
         [  -1.,    0.,    0.,   -1.,   -1.],
         [   0.,   -2.,  -10.,    7.,    0.],
         [  -7.,  -10.,    0.,    0.,   -1.],
         [   0.,    0.,   -5.,    0.,   15.],
         [   0.,    0.,  -24.,    0.,   -2.],
         [   0.,    0.,    0.,   -1.,    2.],
         [  -4.,    0.,    0.,    0.,   -2.],
         [   1.,    1.,  -81.,   -5.,   -1.],
         [  -1.,   -5.,   -1.,   -3.,    0.]]], dtype=torch.float16)

Shape: torch.Size([5, 10, 5])

Difference

--- actual
+++ expected
@@ -1,54 +1,54 @@
-tensor([[[   0.,    1.,  -15.,    0.,   -1.],
-         [   0.,    2.,    0.,    0.,    4.],
-         [  -1.,    0.,    0.,  -54.,    0.],
-         [   0.,    0.,   -1.,    0.,    4.],
+tensor([[[  -0.,    1.,  -15.,   -0.,   -1.],
+         [   0.,    2.,   -0.,   -0.,    4.],
+         [  -1.,    0.,   -0.,  -54.,   -0.],
+         [   0.,   -0.,   -1.,   -0.,    4.],
          [   2.,    0.,    5.,    0.,   -1.],
-         [   0.,   -2.,    0.,    7.,    0.],
-         [   3.,    0.,    1.,    0.,   -1.],
-         [  -4.,   -2.,    0.,    0.,   11.],
+         [  -0.,   -2.,   -0.,    7.,    0.],
+         [   3.,    0.,    1.,   -0.,   -1.],
+         [  -4.,   -2.,   -0.,   -0.,   11.],
          [   0.,    0.,    0.,   18.,    5.],
-         [   2.,   -3.,    0.,    0.,   13.]],
+         [   2.,   -3.,    0.,   -0.,   13.]],

-        [[  -3.,   42.,    0.,    0.,   -2.],
+        [[  -3.,   42.,   -0.,   -0.,   -2.],
          [ -55.,    0.,    3.,   -8.,    1.],
-         [  -1.,    0.,    0.,    0.,    0.],
-         [   3.,    1.,    1.,    0.,    0.],
-         [   0.,    1.,    0.,    0.,   -3.],
-         [   2.,    0.,    3.,    0.,   -1.],
+         [  -1.,    0.,    0.,   -0.,   -0.],
+         [   3.,    1.,    1.,   -0.,   -0.],
+         [   0.,    1.,   -0.,    0.,   -3.],
+         [   2.,   -0.,    3.,   -0.,   -1.],
          [   0.,    0.,    1.,    3.,   88.],
-         [   1.,    0.,    1.,   51.,   -1.],
-         [   1.,   -1.,    0.,    0.,    0.],
-         [   5.,    0.,    0.,   -2.,    0.]],
+         [   1.,   -0.,    1.,   51.,   -1.],
+         [   1.,   -1.,   -0.,    0.,    0.],
+         [   5.,    0.,   -0.,   -2.,   -0.]],

-        [[   0.,   -2.,    0.,    0.,    0.],
-         [   0.,   -2.,    0.,    1.,   -1.],
+        [[  -0.,   -2.,   -0.,    0.,    0.],
+         [  -0.,   -2.,   -0.,    1.,   -1.],
          [   0.,    0.,   -1.,   -1.,    0.],
          [   0.,  -12.,   -1.,   -1.,   -1.],
-         [   0.,   -1.,    1.,    0.,   -2.],
-         [   1.,    1.,    0.,    1.,    0.],
-         [   0.,    3.,    1.,    0.,    1.],
-         [   1.,  -53.,    0.,    0.,    2.],
-         [  -1.,    4.,   -5.,    1.,    0.],
-         [   0.,    4.,   -1.,    0.,    1.]],
+         [  -0.,   -1.,    1.,   -0.,   -2.],
+         [   1.,    1.,   -0.,    1.,    0.],
+         [  -0.,    3.,    1.,   -0.,    1.],
+         [   1.,  -53.,   -0.,   -0.,    2.],
+         [  -1.,    4.,   -5.,    1.,   -0.],
+         [  -0.,    4.,   -1.,   -0.,    1.]],

         [[   1.,    1., -235.,    0.,   -1.],
-         [   0.,    0.,    1.,    5.,    0.],
-         [  -3.,    0.,    1.,    0.,    0.],
-         [ -17.,   48.,    0.,    0.,    0.],
-         [  -1.,   -1.,    0.,    0.,    1.],
-         [   0.,    0.,    0.,    0.,    0.],
+         [  -0.,    0.,    1.,    5.,   -0.],
+         [  -3.,   -0.,    1.,    0.,    0.],
+         [ -17.,   48.,    0.,   -0.,    0.],
+         [  -1.,   -1.,   -0.,   -0.,    1.],
+         [  -0.,   -0.,   -0.,    0.,   -0.],
          [   0.,   -5.,   -1.,   -1.,   -1.],
-         [   0.,   -1.,    0.,   -3.,   -1.],
-         [   0.,    0.,   48.,    0.,    0.],
-         [   0.,    0.,    0.,    0.,    0.]],
+         [   0.,   -1.,   -0.,   -3.,   -1.],
+         [   0.,    0.,   49.,    0.,    0.],
+         [  -0.,   -0.,   -0.,    0.,    0.]],

         [[  -1.,    0.,   -1.,   -2.,   -2.],
-         [  -1.,    0.,    0.,   -1.,   -1.],
+         [  -1.,    0.,   -0.,   -1.,   -1.],
          [   0.,   -2.,  -10.,    7.,    0.],
-         [  -7.,  -10.,    0.,    0.,   -1.],
-         [   0.,    0.,   -5.,    0.,   15.],
+         [  -7.,  -10.,   -0.,   -0.,   -1.],
+         [  -0.,   -0.,   -5.,   -0.,   15.],
          [   0.,    0.,  -24.,    0.,   -2.],
-         [   0.,    0.,    0.,   -1.,    2.],
-         [  -4.,    0.,    0.,    0.,   -2.],
+         [   0.,   -0.,    0.,   -1.,    2.],
+         [  -4.,    0.,    0.,   -0.,   -2.],
          [   1.,    1.,  -81.,   -5.,   -1.],
-         [  -1.,   -5.,   -1.,   -3.,    0.]]], dtype=torch.float16)
+         [  -1.,   -5.,   -1.,   -3.,   -0.]]], dtype=torch.float16)

Full error stack

Tensor-likes are not close!

Mismatched elements: 1 / 250 (0.4%)
Greatest absolute difference: 1.0 at index (3, 8, 2) (up to 1e-05 allowed)
Greatest relative difference: 0.0204010009765625 at index (3, 8, 2) (up to 0.001 allowed)
  File "/home/justinchu/dev/onnx-script/onnxscript/tests/function_libs/torch_lib/ops_test.py", line 259, in run_test_output_match
    torch.testing.assert_close(
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1511, in assert_close
    raise error_metas[0].to_error(msg)
justinchuby commented 1 year ago

Summary

The output of ONNX Runtime does not match that of PyTorch when executing test ops_test.TestOutputConsistencyEagerCPU.test_output_match_opinfo__div_mode_trunc_rounding_cpu_float16, sample 5 in ONNX Script TorchLib.

To recreate this report, use

CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_opinfo__div_mode_trunc_rounding_cpu_float16

Inputs

Shapes: ['Tensor<torch.Size([5, 10, 5]), dtype=torch.float16>', 'Tensor<torch.Size([5, 10, 5]), dtype=torch.float16>']

inputs = (tensor([[[-6.6953,  6.7070, -7.7422, -5.9688, -3.5234],
         [-0.0703,  8.8828,  5.0195, -1.2656, -8.9922],
         [-1.1250,  2.6641, -6.7773,  5.7227, -5.0273],
         [ 0.3164,  1.3975,  8.9453, -0.0879, -8.9297],
         [ 2.0391,  0.8613,  5.7812, -1.8369, -7.1797],
         [ 1.1426, -5.4570, -7.6719, -8.5312,  1.0459],
         [ 5.9062,  8.0781, -3.4883, -3.8496,  5.4844],
         [ 8.5000,  3.6836,  3.6992, -8.0938,  7.4805],
         [-6.7773,  1.9863,  0.9756, -4.5273,  4.1484],
         [-8.1406, -7.6641,  7.5586, -0.9756,  7.8672]],

        [[-1.2920, -6.0391,  6.1953, -1.7842,  1.4238],
         [ 6.3203, -0.5977,  5.9766,  8.7422, -7.5938],
         [ 5.0898, -1.9688, -5.5117,  2.4258, -1.8369],
         [ 8.0391,  4.4219,  7.8672,  3.4375,  1.1777],
         [ 1.7051, -5.4844, -3.3828,  0.2812, -2.9609],
         [-4.9648, -0.6152,  7.7500, -4.8789,  3.2871],
         [ 0.1846, -2.4180,  5.8789, -8.6719, -6.9883],
         [ 3.1992, -4.0625, -5.1602,  6.8125,  2.4785],
         [-8.7734, -7.0234,  2.4258, -1.9951,  5.3438],
         [-6.5117, -1.2832, -0.5713,  4.2188, -5.0273]],

        [[-0.0439,  7.2773, -1.8369, -3.2168,  4.5273],
         [-4.2031,  1.6787,  0.4219,  4.4922, -8.8594],
         [ 1.5029,  1.9248, -7.6211, -4.8164,  0.3955],
         [ 3.2695, -4.4570, -8.1406,  7.8320,  6.5391],
         [ 0.3164, -6.5469,  3.0586, -4.6406, -6.3555],
         [-7.4180, -3.6738,  0.8613, -2.8555, -0.2812],
         [ 1.0107, -7.3281, -6.7852, -0.3867, -0.8525],
         [ 8.6328,  1.4062,  2.2422,  2.8301, -7.8828],
         [-8.8516,  6.9609,  8.7969,  6.4531, -4.0352],
         [-4.0000,  2.8301,  7.6562, -2.9805,  6.9531]],

        [[-3.5078,  2.5234,  8.2812, -1.5029,  5.8203],
         [-0.5537,  5.6094, -7.6484, -4.7031,  3.3828],
         [ 8.6562, -2.1094,  0.9053,  8.1562,  3.4453],
         [ 7.1172, -6.8477,  1.5381,  0.3340,  4.1836],
         [ 5.9844, -2.3379, -4.2812,  7.6094,  7.1797],
         [-0.9141, -5.4219, -4.3945, -1.7314,  4.4375],
         [-1.0723,  7.2422, -8.6953, -2.9883, -8.8359],
         [-6.1875,  6.8125, -4.5078,  8.8359, -5.4922],
         [-3.2617, -0.4658,  7.7500, -4.8672,  1.2480],
         [ 0.7998,  3.1016, -1.6787,  7.0234, -2.6191]],

        [[ 8.4219, -1.5029, -8.5625,  5.8359,  6.1250],
         [-3.9551,  1.6611,  1.0898, -7.6914, -6.2500],
         [-4.0000,  2.9961, -6.3281,  6.0742, -3.0156],
         [-1.6436,  6.8750,  0.4658, -0.1934, -5.1055],
         [-3.6211, -0.2812, -6.6875,  0.4570,  5.0195],
         [-0.1846,  5.4414, -8.9688,  0.8965,  1.7139],
         [-2.2070, -1.3008, -0.2900, -1.2393,  3.0664],
         [-8.5938, -1.3887, -1.7754, -0.7822,  7.1992],
         [-3.2422, -6.6445, -5.7578, -3.4180,  0.9229],
         [ 5.8789,  5.7812, -3.2344, -7.9531, -1.7051]]], dtype=torch.float16), tensor([[[ 6.9062,  5.3008,  0.5010,  8.3516,  3.0312],
         [-2.4961,  4.1562, -7.9805,  7.2852, -2.0566],
         [ 1.1074,  3.4453,  7.9727, -0.1055,  5.6250],
         [ 0.5010, -3.4375, -4.5000,  6.5938, -1.9600],
         [ 0.8350,  5.5977,  1.0283, -4.0859,  5.2734],
         [-4.7188,  2.0742,  7.7695, -1.1250,  2.3828],
         [ 1.6523,  8.7734, -2.8203,  3.9199, -3.9023],
         [-2.0664, -1.4414, -6.8359,  8.4531,  0.6592],
         [-7.6562,  5.7578,  3.8145, -0.2461,  0.7471],
         [-2.8828,  2.4961,  8.5547,  1.9248,  0.5977]],

        [[ 0.4043, -0.1406, -8.1797,  6.5820, -0.5889],
         [-0.1143, -5.0273,  1.7842, -0.9932, -7.0938],
         [-4.8438, -5.2109, -6.0117, -4.1641,  5.3711],
         [ 2.5312,  2.2227,  6.4414, -5.8711, -8.8047],
         [ 6.6875, -4.4219,  6.6523,  7.3477,  0.8701],
         [-1.9600,  1.5732,  1.9512,  7.9531, -2.3125],
         [ 5.7812, -5.1250,  5.3359, -2.2500, -0.0791],
         [ 2.8477,  8.8281, -4.4727,  0.1318, -2.0469],
         [-6.0547,  6.6641, -6.7422, -4.6406,  8.6719],
         [-1.1338, -7.6016,  4.1641, -1.5205,  5.8281]],

        [[ 7.9727, -2.9453,  5.9219, -5.6680,  4.7812],
         [ 8.1562, -0.5977, -2.9355,  3.7441,  5.3789],
         [ 2.3477,  7.3281,  5.0625,  3.6562,  1.4678],
         [ 7.3906,  0.3691,  4.8789, -6.5039, -5.5195],
         [-8.9062,  5.3281,  2.6445,  7.8828,  2.2148],
         [-4.6250, -2.3828, -8.7031, -2.6191, -1.5469],
         [-8.2188, -2.1875, -3.6914,  8.0938, -0.4834],
         [ 5.2812, -0.0264, -5.3359, -8.0000, -3.5859],
         [ 7.9297,  1.7139, -1.7490,  4.3945,  5.4922],
         [ 4.3242,  0.6504, -5.3789,  3.2969,  3.6836]],

        [[-2.6016,  2.4258, -0.0352, -2.3203, -3.0664],
         [ 3.4531,  7.6289, -4.9922, -0.9229, -7.3203],
         [-2.5137,  6.0469,  0.8613,  8.8516,  4.8086],
         [-0.4131, -0.1406,  8.7734, -8.1641,  5.8281],
         [-3.1816,  1.6611,  6.4609, -8.6875,  6.8477],
         [ 3.2344,  6.1016,  4.6055, -6.5547, -7.1016],
         [-1.3887, -1.3359,  7.0039,  2.3906,  7.7344],
         [-8.4609, -3.7695,  7.7266, -2.6016,  2.9961],
         [-5.5195, -2.8652,  0.1582, -7.4531,  5.3711],
         [-5.5273, -5.0977,  2.0391,  8.0312, -8.3438]],

        [[-6.6797, -5.0078,  6.9883, -2.6992, -2.5488],
         [ 3.3672,  4.9492, -1.5293,  4.4375,  3.2168],
         [-7.7422, -1.2305,  0.5977,  0.8613, -5.5273],
         [ 0.2109, -0.6768, -1.1777,  4.1133,  4.2461],
         [ 4.9570,  1.3184,  1.2568, -4.0078,  0.3340],
         [-8.0469,  6.0820,  0.3604,  1.6260, -0.7910],
         [-5.5117,  1.9512, -3.4375,  1.1602,  1.4238],
         [ 1.8809, -2.5664, -6.9453,  7.3984, -3.2266],
         [-2.4082, -3.8398,  0.0703,  0.6416, -0.6240],
         [-4.6836, -0.9844,  2.2148,  2.0117,  4.5547]]], dtype=torch.float16))
kwargs = {'rounding_mode': 'trunc'}

Expected output

expected = tensor([[[  -0.,    1.,  -15.,   -0.,   -1.],
         [   0.,    2.,   -0.,   -0.,    4.],
         [  -1.,    0.,   -0.,  -54.,   -0.],
         [   0.,   -0.,   -1.,   -0.,    4.],
         [   2.,    0.,    5.,    0.,   -1.],
         [  -0.,   -2.,   -0.,    7.,    0.],
         [   3.,    0.,    1.,   -0.,   -1.],
         [  -4.,   -2.,   -0.,   -0.,   11.],
         [   0.,    0.,    0.,   18.,    5.],
         [   2.,   -3.,    0.,   -0.,   13.]],

        [[  -3.,   42.,   -0.,   -0.,   -2.],
         [ -55.,    0.,    3.,   -8.,    1.],
         [  -1.,    0.,    0.,   -0.,   -0.],
         [   3.,    1.,    1.,   -0.,   -0.],
         [   0.,    1.,   -0.,    0.,   -3.],
         [   2.,   -0.,    3.,   -0.,   -1.],
         [   0.,    0.,    1.,    3.,   88.],
         [   1.,   -0.,    1.,   51.,   -1.],
         [   1.,   -1.,   -0.,    0.,    0.],
         [   5.,    0.,   -0.,   -2.,   -0.]],

        [[  -0.,   -2.,   -0.,    0.,    0.],
         [  -0.,   -2.,   -0.,    1.,   -1.],
         [   0.,    0.,   -1.,   -1.,    0.],
         [   0.,  -12.,   -1.,   -1.,   -1.],
         [  -0.,   -1.,    1.,   -0.,   -2.],
         [   1.,    1.,   -0.,    1.,    0.],
         [  -0.,    3.,    1.,   -0.,    1.],
         [   1.,  -53.,   -0.,   -0.,    2.],
         [  -1.,    4.,   -5.,    1.,   -0.],
         [  -0.,    4.,   -1.,   -0.,    1.]],

        [[   1.,    1., -235.,    0.,   -1.],
         [  -0.,    0.,    1.,    5.,   -0.],
         [  -3.,   -0.,    1.,    0.,    0.],
         [ -17.,   48.,    0.,   -0.,    0.],
         [  -1.,   -1.,   -0.,   -0.,    1.],
         [  -0.,   -0.,   -0.,    0.,   -0.],
         [   0.,   -5.,   -1.,   -1.,   -1.],
         [   0.,   -1.,   -0.,   -3.,   -1.],
         [   0.,    0.,   49.,    0.,    0.],
         [  -0.,   -0.,   -0.,    0.,    0.]],

        [[  -1.,    0.,   -1.,   -2.,   -2.],
         [  -1.,    0.,   -0.,   -1.,   -1.],
         [   0.,   -2.,  -10.,    7.,    0.],
         [  -7.,  -10.,   -0.,   -0.,   -1.],
         [  -0.,   -0.,   -5.,   -0.,   15.],
         [   0.,    0.,  -24.,    0.,   -2.],
         [   0.,   -0.,    0.,   -1.,    2.],
         [  -4.,    0.,    0.,   -0.,   -2.],
         [   1.,    1.,  -81.,   -5.,   -1.],
         [  -1.,   -5.,   -1.,   -3.,   -0.]]], dtype=torch.float16)

Shape: torch.Size([5, 10, 5])

Actual output

actual = tensor([[[   0.,    1.,  -15.,    0.,   -1.],
         [   0.,    2.,    0.,    0.,    4.],
         [  -1.,    0.,    0.,  -54.,    0.],
         [   0.,    0.,   -1.,    0.,    4.],
         [   2.,    0.,    5.,    0.,   -1.],
         [   0.,   -2.,    0.,    7.,    0.],
         [   3.,    0.,    1.,    0.,   -1.],
         [  -4.,   -2.,    0.,    0.,   11.],
         [   0.,    0.,    0.,   18.,    5.],
         [   2.,   -3.,    0.,    0.,   13.]],

        [[  -3.,   42.,    0.,    0.,   -2.],
         [ -55.,    0.,    3.,   -8.,    1.],
         [  -1.,    0.,    0.,    0.,    0.],
         [   3.,    1.,    1.,    0.,    0.],
         [   0.,    1.,    0.,    0.,   -3.],
         [   2.,    0.,    3.,    0.,   -1.],
         [   0.,    0.,    1.,    3.,   88.],
         [   1.,    0.,    1.,   51.,   -1.],
         [   1.,   -1.,    0.,    0.,    0.],
         [   5.,    0.,    0.,   -2.,    0.]],

        [[   0.,   -2.,    0.,    0.,    0.],
         [   0.,   -2.,    0.,    1.,   -1.],
         [   0.,    0.,   -1.,   -1.,    0.],
         [   0.,  -12.,   -1.,   -1.,   -1.],
         [   0.,   -1.,    1.,    0.,   -2.],
         [   1.,    1.,    0.,    1.,    0.],
         [   0.,    3.,    1.,    0.,    1.],
         [   1.,  -53.,    0.,    0.,    2.],
         [  -1.,    4.,   -5.,    1.,    0.],
         [   0.,    4.,   -1.,    0.,    1.]],

        [[   1.,    1., -235.,    0.,   -1.],
         [   0.,    0.,    1.,    5.,    0.],
         [  -3.,    0.,    1.,    0.,    0.],
         [ -17.,   48.,    0.,    0.,    0.],
         [  -1.,   -1.,    0.,    0.,    1.],
         [   0.,    0.,    0.,    0.,    0.],
         [   0.,   -5.,   -1.,   -1.,   -1.],
         [   0.,   -1.,    0.,   -3.,   -1.],
         [   0.,    0.,   48.,    0.,    0.],
         [   0.,    0.,    0.,    0.,    0.]],

        [[  -1.,    0.,   -1.,   -2.,   -2.],
         [  -1.,    0.,    0.,   -1.,   -1.],
         [   0.,   -2.,  -10.,    7.,    0.],
         [  -7.,  -10.,    0.,    0.,   -1.],
         [   0.,    0.,   -5.,    0.,   15.],
         [   0.,    0.,  -24.,    0.,   -2.],
         [   0.,    0.,    0.,   -1.,    2.],
         [  -4.,    0.,    0.,    0.,   -2.],
         [   1.,    1.,  -81.,   -5.,   -1.],
         [  -1.,   -5.,   -1.,   -3.,    0.]]], dtype=torch.float16)

Shape: torch.Size([5, 10, 5])

Difference

--- actual
+++ expected
@@ -1,54 +1,54 @@
-tensor([[[   0.,    1.,  -15.,    0.,   -1.],
-         [   0.,    2.,    0.,    0.,    4.],
-         [  -1.,    0.,    0.,  -54.,    0.],
-         [   0.,    0.,   -1.,    0.,    4.],
+tensor([[[  -0.,    1.,  -15.,   -0.,   -1.],
+         [   0.,    2.,   -0.,   -0.,    4.],
+         [  -1.,    0.,   -0.,  -54.,   -0.],
+         [   0.,   -0.,   -1.,   -0.,    4.],
          [   2.,    0.,    5.,    0.,   -1.],
-         [   0.,   -2.,    0.,    7.,    0.],
-         [   3.,    0.,    1.,    0.,   -1.],
-         [  -4.,   -2.,    0.,    0.,   11.],
+         [  -0.,   -2.,   -0.,    7.,    0.],
+         [   3.,    0.,    1.,   -0.,   -1.],
+         [  -4.,   -2.,   -0.,   -0.,   11.],
          [   0.,    0.,    0.,   18.,    5.],
-         [   2.,   -3.,    0.,    0.,   13.]],
+         [   2.,   -3.,    0.,   -0.,   13.]],

-        [[  -3.,   42.,    0.,    0.,   -2.],
+        [[  -3.,   42.,   -0.,   -0.,   -2.],
          [ -55.,    0.,    3.,   -8.,    1.],
-         [  -1.,    0.,    0.,    0.,    0.],
-         [   3.,    1.,    1.,    0.,    0.],
-         [   0.,    1.,    0.,    0.,   -3.],
-         [   2.,    0.,    3.,    0.,   -1.],
+         [  -1.,    0.,    0.,   -0.,   -0.],
+         [   3.,    1.,    1.,   -0.,   -0.],
+         [   0.,    1.,   -0.,    0.,   -3.],
+         [   2.,   -0.,    3.,   -0.,   -1.],
          [   0.,    0.,    1.,    3.,   88.],
-         [   1.,    0.,    1.,   51.,   -1.],
-         [   1.,   -1.,    0.,    0.,    0.],
-         [   5.,    0.,    0.,   -2.,    0.]],
+         [   1.,   -0.,    1.,   51.,   -1.],
+         [   1.,   -1.,   -0.,    0.,    0.],
+         [   5.,    0.,   -0.,   -2.,   -0.]],

-        [[   0.,   -2.,    0.,    0.,    0.],
-         [   0.,   -2.,    0.,    1.,   -1.],
+        [[  -0.,   -2.,   -0.,    0.,    0.],
+         [  -0.,   -2.,   -0.,    1.,   -1.],
          [   0.,    0.,   -1.,   -1.,    0.],
          [   0.,  -12.,   -1.,   -1.,   -1.],
-         [   0.,   -1.,    1.,    0.,   -2.],
-         [   1.,    1.,    0.,    1.,    0.],
-         [   0.,    3.,    1.,    0.,    1.],
-         [   1.,  -53.,    0.,    0.,    2.],
-         [  -1.,    4.,   -5.,    1.,    0.],
-         [   0.,    4.,   -1.,    0.,    1.]],
+         [  -0.,   -1.,    1.,   -0.,   -2.],
+         [   1.,    1.,   -0.,    1.,    0.],
+         [  -0.,    3.,    1.,   -0.,    1.],
+         [   1.,  -53.,   -0.,   -0.,    2.],
+         [  -1.,    4.,   -5.,    1.,   -0.],
+         [  -0.,    4.,   -1.,   -0.,    1.]],

         [[   1.,    1., -235.,    0.,   -1.],
-         [   0.,    0.,    1.,    5.,    0.],
-         [  -3.,    0.,    1.,    0.,    0.],
-         [ -17.,   48.,    0.,    0.,    0.],
-         [  -1.,   -1.,    0.,    0.,    1.],
-         [   0.,    0.,    0.,    0.,    0.],
+         [  -0.,    0.,    1.,    5.,   -0.],
+         [  -3.,   -0.,    1.,    0.,    0.],
+         [ -17.,   48.,    0.,   -0.,    0.],
+         [  -1.,   -1.,   -0.,   -0.,    1.],
+         [  -0.,   -0.,   -0.,    0.,   -0.],
          [   0.,   -5.,   -1.,   -1.,   -1.],
-         [   0.,   -1.,    0.,   -3.,   -1.],
-         [   0.,    0.,   48.,    0.,    0.],
-         [   0.,    0.,    0.,    0.,    0.]],
+         [   0.,   -1.,   -0.,   -3.,   -1.],
+         [   0.,    0.,   49.,    0.,    0.],
+         [  -0.,   -0.,   -0.,    0.,    0.]],

         [[  -1.,    0.,   -1.,   -2.,   -2.],
-         [  -1.,    0.,    0.,   -1.,   -1.],
+         [  -1.,    0.,   -0.,   -1.,   -1.],
          [   0.,   -2.,  -10.,    7.,    0.],
-         [  -7.,  -10.,    0.,    0.,   -1.],
-         [   0.,    0.,   -5.,    0.,   15.],
+         [  -7.,  -10.,   -0.,   -0.,   -1.],
+         [  -0.,   -0.,   -5.,   -0.,   15.],
          [   0.,    0.,  -24.,    0.,   -2.],
-         [   0.,    0.,    0.,   -1.,    2.],
-         [  -4.,    0.,    0.,    0.,   -2.],
+         [   0.,   -0.,    0.,   -1.,    2.],
+         [  -4.,    0.,    0.,   -0.,   -2.],
          [   1.,    1.,  -81.,   -5.,   -1.],
-         [  -1.,   -5.,   -1.,   -3.,    0.]]], dtype=torch.float16)
+         [  -1.,   -5.,   -1.,   -3.,   -0.]]], dtype=torch.float16)

Full error stack

Tensor-likes are not close!

Mismatched elements: 1 / 250 (0.4%)
Greatest absolute difference: 1.0 at index (3, 8, 2) (up to 1e-05 allowed)
Greatest relative difference: 0.0204010009765625 at index (3, 8, 2) (up to 0.001 allowed)
  File "/home/justinchu/dev/onnx-script/onnxscript/tests/function_libs/torch_lib/ops_test.py", line 259, in run_test_output_match
    torch.testing.assert_close(
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1511, in assert_close
    raise error_metas[0].to_error(msg)