pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Getting "RuntimeError: Trying to backward through the graph a second time" when calling backward on compiled unbind / split ops #118739

Open Cztery opened 7 months ago

Cztery commented 7 months ago

🐛 Describe the bug

Since PyTorch 2.2 (and still on the current PT nightly), calling backward() on the outputs of compiled unbind or split ops throws a RuntimeError. As the error message suggests, calling backward() with retain_graph=True prevents the error. However, previous PyTorch versions (2.1) did not require this. Is this expected behaviour, or a new bug that the retain_graph=True workaround hides?
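For comparison, here is a minimal sketch of the behaviour described above, using the same unbind function as the minified repro. In eager mode, calling `backward()` separately on each output works without `retain_graph`. Under `torch.compile`, the report says that passing `retain_graph=True` avoids the error. The `backend="aot_eager"` argument is an assumption of this sketch. It still routes through AOTAutograd (the `_functorch/_aot_autograd` frames in the traceback) but skips Inductor code generation; the reporter used the default backend.

```python
import torch


def foo(x):
    # Same function as the minified repro: split a 1-D tensor into 0-d views.
    return torch.unbind(x, 0)


def grads_eager():
    x = torch.tensor([1.0, 2.0, 3.0, 4.0, 8.0], requires_grad=True)
    for o in foo(x):
        o.backward()  # eager mode: no retain_graph needed
    return x.grad


def grads_compiled_retain():
    x = torch.tensor([1.0, 2.0, 3.0, 4.0, 8.0], requires_grad=True)
    # Assumption: "aot_eager" keeps the AOTAutograd path but skips Inductor codegen.
    compiled = torch.compile(foo, backend="aot_eager")
    for o in compiled(x):
        o.backward(retain_graph=True)  # workaround reported in this issue
    return x.grad


if __name__ == "__main__":
    print(grads_eager())
    print(grads_compiled_retain())
```

Both calls should accumulate a gradient of one into every element of the input, because each 0-d output is a view of exactly one input element.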

Error logs

➜ ~ python3 unbind_repro.py
Traceback (most recent call last):
  File "/home/ubu/unbind_repro.py", line 14, in <module>
    i.backward() #retain_graph=True)
    ^^^^^^^^^^^^
  File "/home/ubu/.local/lib/python3.11/site-packages/torch/_tensor.py", line 524, in backward
    torch.autograd.backward(
  File "/home/ubu/.local/lib/python3.11/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/home/ubu/.local/lib/python3.11/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubu/.local/lib/python3.11/site-packages/torch/autograd/function.py", line 294, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ubu/.local/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 620, in backward
    ctx.saved_tensors,
    ^^^^^^^^^^^^^^^^^
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
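The traceback ends in the AOTAutograd backward wrapper reading `ctx.saved_tensors`, called from `torch/autograd/function.py`. This suggests the compiled region runs as a single custom autograd node shared by every output of unbind. Under that assumption, the failure can be mimicked without `torch.compile`. The sketch below uses a hypothetical `TwoOutputs` function (illustrative only, not part of PyTorch) that saves a tensor and returns two outputs. The first `backward()` frees the node's saved tensors, so a second `backward()` through the same node raises the same error unless the first call passes `retain_graph=True`.

```python
import torch


class TwoOutputs(torch.autograd.Function):
    """Hypothetical stand-in for a compiled region: one autograd node,
    two outputs, and a saved tensor that backward() reads."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x, x * 3

    @staticmethod
    def backward(ctx, grad_a, grad_b):
        # Raises "Trying to backward through the graph a second time"
        # if this node's saved tensors were already freed.
        (x,) = ctx.saved_tensors
        # The gradient of the output not being backpropagated arrives as zeros
        # (materialize_grads defaults to True).
        return grad_a * 2 * x + grad_b * 3


def backward_each_output(retain_graph):
    x = torch.tensor([1.0, 2.0], requires_grad=True)
    a, b = TwoOutputs.apply(x)
    a.sum().backward(retain_graph=retain_graph)  # first pass through the node
    b.sum().backward()  # second pass through the same node
    return x.grad


if __name__ == "__main__":
    print(backward_each_output(retain_graph=True))
    try:
        backward_each_output(retain_graph=False)
    except RuntimeError as err:
        print("RuntimeError:", err)
```

Eager unbind presumably avoids this because its backward formula needs no saved tensors. That is an inference, not something established in this thread.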

Minified repro

import torch

input = torch.Tensor([1, 2, 3, 4, 8])
input.requires_grad = True

def foo(input):
    outputs = torch.unbind(input, 0)
    return outputs

fooCompiled = torch.compile(foo)
out = fooCompiled(input)

for i in out:
    i.backward() #retain_graph=True)

print(out)
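Another option, not mentioned in the thread, is to backpropagate through all outputs in a single call. `torch.autograd.backward` accepts a sequence of root tensors, so the compiled node should be visited once and its saved tensors read only once, which avoids the error without `retain_graph`. The result matches the loop in the repro because each output contributes a unit gradient. `backend="aot_eager"` is an assumption of this sketch (it keeps AOTAutograd but skips Inductor code generation).

```python
import torch


def foo(x):
    # Same function as the minified repro.
    return torch.unbind(x, 0)


def grads_single_backward():
    x = torch.tensor([1.0, 2.0, 3.0, 4.0, 8.0], requires_grad=True)
    # Assumption: "aot_eager" keeps the AOTAutograd path but skips Inductor codegen.
    compiled = torch.compile(foo, backend="aot_eager")
    outs = compiled(x)
    # One backward call with every output as a root, each seeded with a unit
    # gradient: equivalent to calling o.backward() on each output in turn.
    torch.autograd.backward(outs, grad_tensors=[torch.ones_like(o) for o in outs])
    return x.grad


if __name__ == "__main__":
    print(grads_single_backward())
```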

Versions

PyTorch version: 2.3.0.dev20240129+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Ubuntu 23.04 (x86_64)
GCC version: (Ubuntu 12.3.0-1ubuntu1~23.04) 12.3.0
Clang version: 15.0.7
CMake version: version 3.25.1
Libc version: glibc-2.37

Python version: 3.11.4 (main, Dec 7 2023, 15:43:41) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.37
Is CUDA available: False
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 8
On-line CPU(s) list: 0-7
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) i7-8665U CPU @ 1.90GHz
CPU family: 6
Model: 142
Thread(s) per core: 2
Core(s) per socket: 4
Socket(s): 1
Stepping: 12
BogoMIPS: 4224.01
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt xsaveopt xsavec xgetbv1 xsaves flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 128 KiB (4 instances)
L1i cache: 128 KiB (4 instances)
L2 cache: 1 MiB (4 instances)
L3 cache: 8 MiB (1 instance)
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Unknown: Dependent on hypervisor status
Vulnerability Tsx async abort: Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] numpy==1.24.2
[pip3] torch==2.3.0.dev20240129+cpu
[pip3] torchaudio==2.2.0.dev20240129+cpu
[pip3] torchvision==0.18.0.dev20240129+cpu
[conda] Could not collect

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @gchanan @kadeng

BoyuanFeng commented 1 week ago

@bdhirsh (old issue scrubbing) I tried the repro and still get the following error:

Traceback (most recent call last):
  File "/home/boyuan/playground/pt2/repro.py", line 14, in <module>
    i.backward() #retain_graph=True)
  File "/data/users/boyuan/pytorch/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/data/users/boyuan/pytorch/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/data/users/boyuan/pytorch/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/data/users/boyuan/pytorch/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
  File "/data/users/boyuan/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1740, in backward
    ctx_saved_tensors = ctx.saved_tensors
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.