fynnsu opened 9 months ago
The "rrelu_with_noise() missing 2 required positional arguments: 'lower' and 'upper'" problem is a duplicate of https://github.com/pytorch/pytorch/issues/115811 and should already be fixed. The torch.nn.RReLU version still doesn't work, though.
The problem seems to be that torch.nn.RReLU defaults to training=True (modules start in training mode) and therefore takes a different code path than torch.nn.functional.rrelu, whose training argument defaults to False. This code path then copies in place into the RReLU noise tensor, which triggers the assert. I've put up a PR that sidesteps the problem, but it needs careful review because it violates the invariant mentioned in the code comments.
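To make the two code paths concrete, here is a minimal pure-Python sketch of RReLU semantics on a flat list of floats. It is an illustration, not PyTorch's kernel or decomposition: the function name rrelu_reference and the rng parameter are hypothetical, and the defaults lower=1/8 and upper=1/3 follow torch.nn.RReLU. Unlike the in-place path, it returns the sampled noise instead of copying it into a caller-provided tensor.

```python
import random
from typing import List, Optional, Tuple


def rrelu_reference(
    x: List[float],
    lower: float = 1.0 / 8,
    upper: float = 1.0 / 3,
    training: bool = False,
    rng: Optional[random.Random] = None,
) -> Tuple[List[float], Optional[List[float]]]:
    """Toy RReLU on a list of floats; returns (output, noise) instead of mutating."""
    if not training:
        # Eval path: one fixed negative slope, the midpoint of [lower, upper].
        # No noise is produced, mirroring a path that never writes to it.
        slope = (lower + upper) / 2
        return [v if v > 0 else v * slope for v in x], None

    # Training path: each non-positive element gets its own random slope.
    rng = rng or random.Random()
    out: List[float] = []
    noise: List[float] = []
    for v in x:
        if v <= 0:
            r = rng.uniform(lower, upper)
            out.append(v * r)
            noise.append(r)    # the in-place path would copy this into the noise tensor
        else:
            out.append(v)
            noise.append(1.0)  # positive elements pass through unscaled
    return out, noise


x = [-2.0, 0.5, -1.0, 3.0]
print(rrelu_reference(x))  # eval mode: deterministic
print(rrelu_reference(x, training=True, rng=random.Random(0)))  # training mode: sampled slopes
```

Calling an nn.RReLU() module in its default training mode corresponds to training=True here, while torch.nn.functional.rrelu defaults to training=False and takes the eval branch.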
I think the issue may be that the decomp for torch._ops.aten.rrelu_with_noise.default contains an in-place op, copy_. Since FunctionalTensorMode runs above ProxyTensorMode (which runs the decomps), maybe it's worth establishing a general rule that decomps for out-of-place ops must not contain in-place operations.
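The rule proposed above could be enforced with a check over the ops a decomposition emits. The sketch below is a heuristic based on the ATen naming convention that in-place variants end in a trailing underscore (add_, copy_); the helper names is_inplace_overload and find_inplace_ops and the op list are hypothetical, and the list is not the trace from this issue. A robust check would inspect each operator's schema for mutation (alias) annotations rather than its name.

```python
from typing import Iterable, List


def is_inplace_overload(qualified_name: str) -> bool:
    """Heuristic: does a name like "aten.copy_.default" denote an in-place op?

    Expects "namespace.op.overload"; only the op part is inspected.
    """
    parts = qualified_name.split(".")
    op = parts[1] if len(parts) >= 2 else parts[0]
    if op.startswith("__") and op.endswith("__"):
        # Dunder ops: __iand__ and friends are in-place, __and__ is not.
        return op.startswith("__i")
    # Regular ops: a trailing underscore marks the in-place variant.
    return op.endswith("_")


def find_inplace_ops(op_names: Iterable[str]) -> List[str]:
    """Return the ops that would violate a 'no in-place ops in decomps' rule."""
    return [name for name in op_names if is_inplace_overload(name)]


# Illustrative op list, loosely shaped like an RReLU training-mode decomposition.
decomp_ops = [
    "aten.le.Scalar",
    "aten.uniform.default",
    "aten.mul.Tensor",
    "aten.where.self",
    "aten.copy_.default",
]
print(find_inplace_ops(decomp_ops))  # ['aten.copy_.default']
```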
I have seen a similar issue with 'torch.ops.aten._scaled_dot_product_flash_attention_for_cpu' in our custom backend; after removing this op from the decomposition table, the error goes away.
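The workaround described here, dropping an op from the decomposition table so the backend receives it undecomposed, can be illustrated with a toy table. In PyTorch the table is typically keyed by operator overload objects rather than strings; this sketch uses plain strings, and without_ops, lower, and the string decompositions are hypothetical stand-ins.

```python
from typing import Callable, Dict, Iterable

# Hypothetical toy table: op name -> callable that "decomposes" the op.
DecompTable = Dict[str, Callable[[], str]]


def without_ops(table: DecompTable, excluded: Iterable[str]) -> DecompTable:
    """Return a filtered copy of the table; the original table is left untouched."""
    skip = set(excluded)
    return {op: fn for op, fn in table.items() if op not in skip}


def lower(op: str, table: DecompTable) -> str:
    """Decompose op if the table has an entry, else pass it to the backend unchanged."""
    if op in table:
        return table[op]()
    return f"backend kernel: {op}"


SDPA_CPU = "aten._scaled_dot_product_flash_attention_for_cpu.default"
RRELU = "aten.rrelu_with_noise.default"

table: DecompTable = {
    SDPA_CPU: lambda: "decomposed",
    RRELU: lambda: "decomposed",
}

custom = without_ops(table, [SDPA_CPU])
print(lower(SDPA_CPU, custom))  # handled by the backend kernel
print(lower(RRELU, custom))     # still decomposed
```

Returning a filtered copy rather than popping from the shared table keeps the default table intact for other backends.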
🐛 Describe the bug
Dynamo fails to produce the graph for RReLU.
It fails when checking whether this graph is functional.
However, this graph seems to contain several unnecessary operations (including copy_, the one that fails). The graph below should be sufficient to compute the RReLU operation. I initially thought the extra lines might be for the inplace=True graph, but they aren't used in that graph either (below).
Error logs
Minified repro
I also tried calling torch.rrelu directly and ran into a different error (rrelu_with_noise() missing 2 required positional arguments: 'lower' and 'upper'); this might be a separate bug:
Versions
PyTorch version: 2.2.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.9.0 (default, Nov 15 2020, 14:28:56) [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-6.2.0-37-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 1080 Ti
Nvidia driver version: 535.113.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 7 3800X 8-Core Processor
CPU family: 23
Model: 113
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 0
Frequency boost: enabled
CPU max MHz: 4558.8862
CPU min MHz: 2200.0000
BogoMIPS: 7800.12
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sev sev_es
Virtualization: AMD-V
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 4 MiB (8 instances)
L3 cache: 32 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] mypy==1.8.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.3
[pip3] torch==2.2.0
[pip3] torchaudio==2.2.0
[pip3] torchvision==0.17.0
[pip3] torchviz==0.0.2
[pip3] triton==2.2.0
[conda] numpy 1.26.3 pypi_0 pypi
[conda] torch 2.2.0 pypi_0 pypi
[conda] torchaudio 2.2.0 pypi_0 pypi
[conda] torchvision 0.17.0 pypi_0 pypi
[conda] torchviz 0.0.2 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi
cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519