pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

ONNX Export - miscompilation for complex-valued operators #113444

Open timstokman opened 1 year ago

timstokman commented 1 year ago

🐛 Describe the bug

The initial issue was fixed, so I changed the description.

I'm trying to export a model to ONNX with torch.onnx.dynamo_export. It's a somewhat experimental operator that uses complex-valued tensors and FFT operations, and I'm experimenting with deployment options. The exported model shows large accuracy errors compared to the eager PyTorch output.

import torch
import torch.nn as nn
import onnxruntime
import numpy as np

def fftconv(u, k, D):
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen

    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)

    if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]

    out = y + u * D.unsqueeze(-1)
    return out.to(dtype=u.dtype)

class Filter(nn.Module):
    def forward(self, x, k=None, bias=None):
        y = fftconv(x, k, bias)
        return y

if __name__ == "__main__":
    filter = Filter().eval()
    x_input, k_input, bias_input = torch.rand(1, 512, 1024), torch.rand(512, 1024), torch.rand(512)
    export_output = torch.onnx.dynamo_export(filter, x_input, k_input, bias_input)
    export_output.save('test.onnx')
    session = onnxruntime.InferenceSession('test.onnx')
    test_output = filter(x_input, k_input, bias_input)
    test_output_onnx = session.run([session.get_outputs()[0].name], {'arg0': x_input.numpy(), 'arg1': k_input.numpy(), 'arg2': bias_input.numpy()})[0]
    np.testing.assert_allclose(test_output_onnx, test_output, 0.02, 0.02)

Output:

/home/timstokman/.conda/envs/test/lib/python3.11/site-packages/torch/onnx/_internal/exporter.py:130: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
  warnings.warn(
/home/timstokman/.conda/envs/test/lib/python3.11/site-packages/onnxscript/function_libs/torch_lib/graph_building.py:972: UserWarning: ONNX model is invalid: [ShapeInferenceError] (op_type:Squeeze, node name: Squeeze_41): [ShapeInferenceError] Inferred shape and existing shape differ in dimension 2: (1025) vs (2048)
  warnings.warn(f"ONNX model is invalid: {e}", stacklevel=1)
2023-11-29 19:06:53.076644636 [W:onnxruntime:, graph.cc:108 MergeShapeInfo] Error merging shape info for output. '_fft_c2r' source:{1,512,1025} target:{1,512,2048}. Falling back to lenient merge.
Traceback (most recent call last):
  File "/home/timstokman/Development/test/repro.py", line 35, in <module>
    np.testing.assert_allclose(test_output_onnx, test_output, 0.02, 0.02)
  File "/home/timstokman/.conda/envs/test/lib/python3.11/site-packages/numpy/testing/_private/utils.py", line 1504, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/home/timstokman/.conda/envs/test/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/timstokman/.conda/envs/test/lib/python3.11/site-packages/numpy/testing/_private/utils.py", line 797, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=0.02, atol=0.02

Mismatched elements: 517665 / 524288 (98.7%)
Max absolute difference: 207.37213
Max relative difference: 9381.009
 x: array([[[63.46427 , 63.431755, 63.74789 , ..., 64.399574, 63.802116,
         63.299534],
        [65.65336 , 66.02201 , 66.56686 , ..., 66.12466 , 65.83544 ,...
 y: array([[[1.080449e+00, 1.316129e+00, 1.437861e+00, ..., 2.505085e+02,
         2.560666e+02, 2.489434e+02],
        [1.025918e+00, 9.075544e-01, 1.936453e+00, ..., 2.610434e+02,...
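As an independent check on what the exported graph should compute, the FFT convolution in fftconv can be mirrored in NumPy and compared against a direct convolution. This is a minimal sketch, not part of the original report: fftconv_ref and direct_causal_conv are hypothetical names, and only the 3-D input case used in the repro is handled (the unsqueeze branch for 4-D input is omitted).

```python
import numpy as np


def fftconv_ref(u, k, D):
    """NumPy mirror of the repro's fftconv for u of shape (batch, channels, seqlen)."""
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen  # padding to >= 2*seqlen - 1 avoids circular wrap-around
    k_f = np.fft.rfft(k, n=fft_size) / fft_size
    u_f = np.fft.rfft(u, n=fft_size)
    # norm="forward" leaves the inverse transform unscaled; the 1/fft_size
    # already folded into k_f plays the role of the usual inverse scaling.
    y = np.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
    return y + u * D[:, None]


def direct_causal_conv(u, k, D):
    """Same quantity via np.convolve: y[t] = sum over s <= t of u[s] * k[t - s]."""
    batch, channels, seqlen = u.shape
    y = np.empty_like(u)
    for b in range(batch):
        for c in range(channels):
            # full linear convolution has length 2*seqlen - 1; keep the first seqlen samples
            y[b, c] = np.convolve(u[b, c], k[c])[:seqlen]
    return y + u * D[:, None]


rng = np.random.default_rng(0)
u, k, D = rng.random((1, 4, 16)), rng.random((4, 16)), rng.random(4)
assert np.allclose(fftconv_ref(u, k, D), direct_causal_conv(u, k, D))
```

Calling fftconv_ref(x_input.numpy(), k_input.numpy(), bias_input.numpy()) on the repro's inputs would give a reference to compare against both the eager and the ONNX Runtime outputs.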

Versions

Collecting environment information...
PyTorch version: 2.1.0
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 12 (bookworm) (x86_64)
GCC version: (Debian 11.3.0-12) 11.3.0
Clang version: Could not collect
CMake version: version 3.27.7
Libc version: glibc-2.36

Python version: 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.1.0-13-amd64-x86_64-with-glibc2.36
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3080 Ti
Nvidia driver version: 525.125.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 7 7700X 8-Core Processor
CPU family: 25
Model: 97
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 2
Frequency boost: enabled
CPU(s) scaling MHz: 85%
CPU max MHz: 5572,2651
CPU min MHz: 3000,0000
BogoMIPS: 8983,39
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca fsrm flush_l1d
Virtualization: AMD-V
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 8 MiB (8 instances)
L3 cache: 32 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.1
[pip3] onnx==1.15.0
[pip3] onnxscript==0.1.0.dev20231108
[pip3] pytorch-lightning==1.8.6
[pip3] torch==2.1.0
[pip3] torchdata==0.7.0
[pip3] torchmetrics==1.2.0
[pip3] torchtext==0.16.0
[pip3] torchvision==0.16.0
[pip3] triton==2.1.0
[conda] blas 1.0 mkl
[conda] mkl 2023.1.0 h213fc3f_46343
[conda] numpy 1.26.1 pypi_0 pypi
[conda] pytorch 2.1.0 py3.11_cuda11.8_cudnn8.7.0_0 pytorch
[conda] pytorch-cuda 11.8 h7e8668a_5 pytorch
[conda] pytorch-lightning 1.8.6 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchdata 0.7.0 pypi_0 pypi
[conda] torchmetrics 1.2.0 pypi_0 pypi
[conda] torchtext 0.16.0 pypi_0 pypi
[conda] torchtriton 2.1.0 py311 pytorch
[conda] torchvision 0.16.0 pypi_0 pypi

BowenBao commented 1 year ago

cc @justinchuby for complex ops export

justinchuby commented 1 year ago

Hi! Thanks for reporting the error. torch.onnx.dynamo_export is the new export API we are developing and the supported operators are not yet documented (the documentation is for the torch.onnx.export API). We are going to add complex support to arithmetic operators in the coming days.
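Most standard ONNX operators do not accept complex inputs, so complex arithmetic has to be lowered onto real tensors. Assuming complex values are carried as real arrays with a trailing dimension of size 2 holding (real, imaginary), the layout torch.view_as_real produces, the u_f * k_f product in fftconv decomposes into real operations as below. This is a NumPy sketch of the arithmetic only; complex_mul_real_view and to_real_view are hypothetical names and this is not the onnxscript implementation.

```python
import numpy as np


def to_real_view(z):
    """Complex array -> real array with trailing (re, im) axis, like torch.view_as_real."""
    return np.stack([z.real, z.imag], axis=-1)


def complex_mul_real_view(a, b):
    """Multiply two complex tensors stored in the [..., 2] real layout."""
    a_re, a_im = a[..., 0], a[..., 1]
    b_re, b_im = b[..., 0], b[..., 1]
    # (a_re + i*a_im)(b_re + i*b_im) = (a_re*b_re - a_im*b_im) + i*(a_re*b_im + a_im*b_re)
    re = a_re * b_re - a_im * b_im
    im = a_re * b_im + a_im * b_re
    return np.stack([re, im], axis=-1)


z1 = np.array([1 + 2j, 3 - 1j])
z2 = np.array([2 - 1j, 0.5 + 4j])
prod = complex_mul_real_view(to_real_view(z1), to_real_view(z2))
assert np.allclose(prod[..., 0] + 1j * prod[..., 1], z1 * z2)
```

Division by a real scalar, as in k_f / fft_size, needs no special handling in this layout because it scales both components equally.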

justinchuby commented 1 year ago

This should be the fix: https://github.com/microsoft/onnxscript/pull/1144

timstokman commented 1 year ago

Wow, that's amazingly fast, thanks!

justinchuby commented 1 year ago

Also need https://github.com/microsoft/onnxscript/pull/1147

justinchuby commented 1 year ago

[Image: exported model]

thiagocrepaldi commented 11 months ago

The provided repro works with the latest PyTorch and onnxscript main branches. Can you double-check, @timstokman? Feel free to close the issue if it works for you.

timstokman commented 11 months ago

It's now being miscompiled @thiagocrepaldi:

(Same repro script and resulting output as in the issue description above.)

Quite a bit better, but not quite there yet. It's not just the maximum: the average difference is also quite large.
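To quantify the average error alongside the maxima that assert_allclose prints, a small helper can report both. error_summary is a hypothetical name; with the repro's variables it would be called as error_summary(test_output_onnx, test_output.numpy()).

```python
import numpy as np


def error_summary(actual, expected):
    """Max and mean absolute/relative error between two arrays."""
    actual = np.asarray(actual, dtype=np.float64)
    expected = np.asarray(expected, dtype=np.float64)
    abs_err = np.abs(actual - expected)
    # avoid division by zero where expected == 0
    rel_err = abs_err / np.maximum(np.abs(expected), np.finfo(np.float64).tiny)
    return {
        "max_abs": float(abs_err.max()),
        "mean_abs": float(abs_err.mean()),
        "max_rel": float(rel_err.max()),
        "mean_rel": float(rel_err.mean()),
    }


summary = error_summary([1.0, 2.0], [1.0, 1.0])
```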

timstokman commented 11 months ago

Versions used for testing:

Brotli             1.0.9
certifi            2023.11.17
cffi               1.16.0
charset-normalizer 2.0.4
coloredlogs        15.0.1
cryptography       41.0.3
filelock           3.13.1
flatbuffers        23.5.26
gmpy2              2.1.2
humanfriendly      10.0
idna               3.4
Jinja2             3.1.2
MarkupSafe         2.1.1
mkl-fft            1.3.8
mkl-random         1.2.4
mkl-service        2.4.0
mpmath             1.3.0
networkx           3.1
numpy              1.26.0
onnx               1.15.0
onnxruntime        1.16.3
onnxscript         0.1.0.dev20231129
packaging          23.2
Pillow             10.0.1
pip                23.3.1
protobuf           4.25.1
pycparser          2.21
pyOpenSSL          23.2.0
PySocks            1.7.1
PyYAML             6.0.1
requests           2.31.0
setuptools         68.0.0
sympy              1.11.1
torch              2.1.1
torchaudio         2.1.1
torchvision        0.16.1
triton             2.1.0
typing_extensions  4.7.1
urllib3            1.26.18
wheel              0.41.2

timstokman commented 11 months ago

Do I need to create a new item for this?

thiagocrepaldi commented 11 months ago

> Do I need to create a new item for this?

Not needed; just update the title and description. It's not the root cause in your case, but you're probably better off using torch.testing.assert_close (or similar) rather than <= for floating-point comparisons.
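For reference, torch.testing.assert_close and np.testing.assert_allclose both treat values as close when, elementwise, |actual - expected| <= atol + rtol * |expected|. The NumPy sketch below spells that criterion out and shows why an exact comparison is fragile for floats; is_close is a hypothetical helper name.

```python
import numpy as np


def is_close(actual, expected, rtol, atol):
    """Elementwise |actual - expected| <= atol + rtol * |expected|."""
    actual = np.asarray(actual, dtype=np.float64)
    expected = np.asarray(expected, dtype=np.float64)
    return np.abs(actual - expected) <= atol + rtol * np.abs(expected)


a = np.array([0.1 + 0.2])  # rounds to slightly more than 0.3
b = np.array([0.3])
assert not np.array_equal(a, b)                        # exact comparison fails
assert is_close(a, b, rtol=1e-9, atol=0.0).all()       # tolerance-based comparison passes
np.testing.assert_allclose(a, b, rtol=1e-9, atol=0.0)  # same criterion; raises on failure
```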

@justinchuby there seems to be a huge numerical discrepancy. Any chance we missed something?

timstokman commented 11 months ago

> Do I need to create a new item for this?
>
> Not needed; just update the title and description. It's not the root cause in your case, but you're probably better off using torch.testing.assert_close (or similar) rather than <= for floating-point comparisons.

Both done

justinchuby commented 11 months ago

Doesn’t look right. Will look deeper

titaiwangms commented 11 months ago

@timstokman you can invoke onnxruntime simply with:

if __name__ == "__main__":
    filter = Filter().eval()
    args = torch.rand(1, 512, 1024), torch.rand(512, 1024), torch.rand(512)
    export_output = torch.onnx.dynamo_export(filter, *args, export_options=torch.onnx.ExportOptions(op_level_debug=True))
    export_output.save_diagnostics('test_fft.sarif')
    test_output = filter(*args)
    test_output_onnx_format = export_output.adapt_torch_outputs_to_onnx(test_output)
    test_output_onnx = export_output(*args)
    np.testing.assert_allclose(test_output_onnx, test_output_onnx_format, 0.02, 0.02)

This works once https://github.com/pytorch/pytorch/pull/114885 is merged, using a nightly torch build.

@justinchuby In this case, turning on op-level debug shows that the issue comes from the _fft_c2r implementation.

[Screenshot: fft_op_level_debug]

justinchuby commented 11 months ago

We don't have the correct implementation for fft_c2r in onnxscript/torchlib yet...
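For context on what a correct fft_c2r has to produce: a complex-to-real inverse FFT takes the n // 2 + 1 non-redundant bins of a Hermitian-symmetric spectrum (1025 bins for fft_size = 2048 in the repro) and returns n real samples (2048). That is the same pair of sizes as in the repro's '_fft_c2r' shape-merge warning ({1,512,1025} vs {1,512,2048}). The NumPy sketch below rebuilds the full spectrum from Hermitian symmetry and applies a full inverse FFT. c2r_via_hermitian is a hypothetical helper, not the torchlib implementation, and it uses NumPy's default "backward" normalization (the repro's norm='forward' would scale the result by n).

```python
import numpy as np


def c2r_via_hermitian(half_spectrum, n):
    """Inverse real FFT along the last axis from the n // 2 + 1 non-redundant bins."""
    assert half_spectrum.shape[-1] == n // 2 + 1
    # Bins n//2+1 .. n-1 equal conj(X[n - j]): the conjugated bins 1 .. (n-1)//2, reversed.
    mirrored = np.conj(half_spectrum[..., 1:(n + 1) // 2][..., ::-1])
    full = np.concatenate([half_spectrum, mirrored], axis=-1)  # length n
    # For a Hermitian spectrum the imaginary part of the result is only rounding noise.
    return np.fft.ifft(full, n=n, axis=-1).real


x = np.random.default_rng(0).random((2, 2048))
half = np.fft.rfft(x)                 # shape (2, 1025)
y = c2r_via_hermitian(half, n=2048)   # shape (2, 2048)
assert np.allclose(y, np.fft.irfft(half, n=2048))
```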

justinchuby commented 2 months ago

Tracked by https://github.com/microsoft/onnxscript/pull/1844