pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

torch.compiled model output gets overwritten despite tensor.detach() #104435

Closed: HDCharles closed this issue 2 months ago

HDCharles commented 1 year ago

🐛 Describe the bug

Related to https://github.com/pytorch/pytorch/blob/5ab1d2c2cc4e9c83b15c98974d6610a03322f40e/torch/_inductor/cudagraph_trees.py#L1889-L1893

At times when you would get this error, if instead of doing out = model(input) you do out = model(input).detach() to try to fix it, you suppress the error message without fixing the problem. Specifically, the value of out will change if you run model(input).detach() again. You have to do model(input) + 0 or something similar to actually fix the problem.

At a high level, I think this bug is either (A) tensor.detach() suppressing an error message without fixing the error, or (B) model outputs getting overwritten despite tensor.detach(), depending on whether B is expected behavior or not.

Either the error message should not be suppressed, or the output value should behave as expected.

@eellison

Error logs

n/a

Minified repro

n/a

My own repro: try running it with/without @torch.compile() and with/without .detach(). Running it as-is should either throw the error message or give the same result as running without @torch.compile.

import torch

@torch.compile(mode='reduce-overhead')
def foo(x):
    return x * x * x

inp = torch.rand([2], device="cuda")
out = foo(inp).detach()
sum_val_1 = out + out
out2 = foo(inp).detach()  # per the report, rerunning foo overwrites the cudagraph-owned buffer that still backs `out`
sum_val_2 = out + out
print(sum_val_1, sum_val_2, out2 + out2)
assert sum_val_1.sum() == sum_val_2.sum()  # fails under torch.compile with .detach() (the reported bug)
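
For comparison, here is a sketch of the workaround described above (my own addition, not part of the original report): copying the output out of the cudagraph-managed buffer with + 0 (or .clone()) before holding on to it should keep the value stable across runs.

import torch

@torch.compile(mode='reduce-overhead')
def foo(x):
    return x * x * x

inp = torch.rand([2], device="cuda")
out = foo(inp) + 0   # + 0 copies the result into a fresh allocation; .clone() works similarly
sum_val_1 = out + out
out2 = foo(inp) + 0  # rerunning foo no longer affects `out`, since `out` does not alias the static buffer
sum_val_2 = out + out
print(sum_val_1, sum_val_2, out2 + out2)
assert sum_val_1.sum() == sum_val_2.sum()  # should pass with the copy-based workaround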

Versions

Collecting environment information...
PyTorch version: 2.1.0a0+git1dba81f
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.26.1
Libc version: glibc-2.31

Python version: 3.9.5 (default, Jun 4 2021, 12:28:51) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.15.0-1019-aws-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
Stepping: 7
CPU MHz: 2500.000
BogoMIPS: 5000.00
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 48 MiB
L3 cache: 71.5 MiB
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==0.960
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.1
[pip3] pytorch-triton==2.1.0+440fd1bf20
[pip3] torch==2.1.0a0+git1dba81f
[pip3] torchvision==0.16.0a0+e5bf7cf
[conda] blas 1.0 mkl
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-include 2023.0.0 h06a4308_25399
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39hd3c417c_0
[conda] mkl_random 1.2.2 py39h51133e4_0
[conda] numpy 1.23.1 pypi_0 pypi
[conda] pytorch-triton 2.1.0+440fd1bf20 pypi_0 pypi
[conda] torch 2.1.0a0+git1dba81f dev_0
[conda] torchvision 0.16.0a0+e5bf7cf dev_0

cc @mcarilli @ezyang @eellison @peterbell10 @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @wconstab

eellison commented 1 year ago

Thank you for the issue! I can extend this to cover Storages, and not just tensors.

Edit: hmm, may need to look further to see how feasible this is.. might be tricky or require other solutions

mlazos commented 6 months ago

@eellison any update on this?

eellison commented 5 months ago

To fix this we'd want to extend https://github.com/pytorch/pytorch/pull/100927 so that it also sets an error message on StorageImpl that throws when called.

isuruf commented 5 months ago

To fix this we'd want to extend https://github.com/pytorch/pytorch/pull/100927 so that it also sets an error message on StorageImpl that throws when called.

Just to clarify, you mean any time data_ptr_ of a StorageImpl is accessed, we should check if there's an error message and throw, right?

Also, should this also use an indirection like ExtraMeta in TensorImpl?

eellison commented 5 months ago

Yes, that's correct. From a conversation with @bdhirsh:

you just want a version of tensor.set_storage_access_should_throw that is really:
...
"given a tensor, set a bit on the storage to indicate to all aliases that their storage is invalid / should throw"
...
I think the only spicy bit is that it feels like to do that we'll need an extra_meta_ field on the c10::StorageImpl, not just TensorImpl (that contains the error message + the bit saying whether we should flip)
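
For context on the aliasing issue being described (an illustration added here, not from the thread): detach() returns a new TensorImpl that shares the same StorageImpl, and therefore the same memory, as the original tensor. That is why an error bit set only on the original TensorImpl is never consulted when the detached alias is read, and why the flag would need to live on the storage.

import torch

x = torch.rand([2])  # CPU tensor for portability; the aliasing behavior is the same on CUDA
y = x.detach()

# x and y are distinct TensorImpls but share one StorageImpl,
# so they read and write the same underlying memory.
assert x.data_ptr() == y.data_ptr()
assert x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr()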