[torch.export] Torch Export produces incorrect program when python generators are used. #130975

Open JamesMBartlett opened 1 month ago

JamesMBartlett commented 1 month ago

🐛 Describe the bug

torch.export.export in strict mode seems to produce an incorrect exported program when the exported module contains a Python generator.

Here is a minimal reproducible example:

import torch
class GeneratorFail(torch.nn.Module):
    def forward(self, x):
        y = [x]
        y.extend(y[-1] + 2 for _ in range(2))
        return torch.cat(y, 1)

model = GeneratorFail()
x = torch.ones((1,1))

prog = torch.export.export(model, (x,), strict=True)
print('before export: {}'.format(model(x)))
print('after export (strict=True): {}'.format(prog.module()(x)))

prog = torch.export.export(model, (x,), strict=False)
print('before export: {}'.format(model(x)))
print('after export (strict=False): {}'.format(prog.module()(x)))

On a nightly build of torch (also reproduced on torch==2.3), this script produces the following output:

before export: tensor([[1., 3., 5.]])
after export (strict=True): tensor([[1., 3., 3.]])
before export: tensor([[1., 3., 5.]])
after export (strict=False): tensor([[1., 3., 5.]])

As you can see, when strict=True the output of the exported program differs from that of the original torch module, while with strict=False the exported program is correct.

In strict mode, the exported program effectively binds y[-1] to its value before the generator starts, so every y[-1] + 2 is computed from the same element and the output is [1., 3., 3.]. In the actual torch module, list.extend consumes the generator lazily, so y is updated on every iteration and each y[-1] picks up the element appended by the previous one, giving [1., 3., 5.] as expected.
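
To see the same divergence without any torch involvement, here is a plain-Python sketch of the two evaluation orders; the helper names are mine and purely illustrative. lazy_extend mirrors eager execution, while snapshot_extend mirrors what the strict-mode graph appears to capture:

def lazy_extend(start):
    # list.extend consumes the generator lazily, appending each item before
    # requesting the next, so every y[-1] sees the previously appended value.
    y = [start]
    y.extend(y[-1] + 2 for _ in range(2))
    return y

def snapshot_extend(start):
    # y[-1] is read once, before any items are produced, so every item is
    # computed from the same base value.
    y = [start]
    base = y[-1]
    y.extend(base + 2 for _ in range(2))
    return y

print(lazy_extend(1))      # [1, 3, 5]
print(snapshot_extend(1))  # [1, 3, 3]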

I'm not too familiar with the PyTorch Dynamo code, but if someone could point me in the right direction I'd be happy to contribute a fix.

Real-world use case

The above example may seem contrived, but the same pattern is used in Ultralytics' YOLOv8, so that model currently will not export correctly with torch.export.export.
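
As a temporary workaround (a sketch, not an endorsed fix; the class name is mine), rewriting the generator as an explicit loop avoids the lazy-iterator handling entirely and should trace to the intended values in strict mode:

import torch

class GeneratorWorkaround(torch.nn.Module):
    def forward(self, x):
        y = [x]
        # Explicit loop instead of a generator: y[-1] is re-read on each
        # iteration, so the appended values are 3 and then 5, as in eager mode.
        for _ in range(2):
            y.append(y[-1] + 2)
        return torch.cat(y, 1)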

Versions

Collecting environment information...
PyTorch version: 2.5.0.dev20240717+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 15.0.6 (https://github.com/llvm/llvm-project.git 088f33605d8a61ff519c580a71b1dd57d16a03f8)
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.11.4 (main, Feb 26 2024, 09:39:10) [Clang 15.0.6 (https://github.com/llvm/llvm-project.git 088f33605d8a61ff519c580a (64-bit runtime)
Python platform: Linux-6.5.0-21-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 4090
GPU 1: NVIDIA GeForce RTX 4090

Nvidia driver version: 545.23.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.7
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD Ryzen Threadripper PRO 5995WX 64-Cores
CPU family: 25
Model: 8
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 1
Stepping: 2
Frequency boost: enabled
CPU max MHz: 7024.2178
CPU min MHz: 1800.0000
BogoMIPS: 5390.08
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization: AMD-V
L1d cache: 2 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 32 MiB (64 instances)
L3 cache: 256 MiB (8 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] torch==2.5.0.dev20240717+cpu
[conda] Could not collect

cc @ezyang @anijain2305 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

angelayi commented 1 month ago

Dynamo seems to incorrectly generate the graph when generators are used. Here are the logs when running with TORCH_LOGS="+dynamo": https://gist.github.com/angelayi/c1d455a3588a5fdce6603dae14ecbc4a

It probably has something to do with ListIteratorVariable and how it saves y[-1] as the same variable (highlighted part of the logs), even though that value changes as the generator runs. cc @williamwen42 if you have better guidance

williamwen42 commented 1 month ago

Dynamo doesn't support generators very well at the moment, although we are aiming to add better support. I also found a possibly related class of errors where Dynamo unpacks iterators eagerly when it should do so lazily; I'm currently working on this in https://github.com/pytorch/pytorch/issues/130750
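
Until then, one way to catch this kind of silent divergence is to compare the eager module against the exported program on the example inputs right after exporting. A rough sketch using only the torch.export calls from the repro above (the helper name is just for illustration, and it assumes the module returns a single tensor):

import torch

def check_export_parity(model, example_inputs, strict=True, atol=1e-6):
    # Export the module and compare its output against eager execution;
    # returns False when the exported program silently diverges.
    prog = torch.export.export(model, example_inputs, strict=strict)
    eager_out = model(*example_inputs)
    exported_out = prog.module()(*example_inputs)
    return torch.allclose(eager_out, exported_out, atol=atol)

# With the GeneratorFail module from the repro:
# check_export_parity(GeneratorFail(), (torch.ones((1, 1)),))                -> False
# check_export_parity(GeneratorFail(), (torch.ones((1, 1)),), strict=False)  -> True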