pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Inconsistent floating point guards with numpy scalars #130920

Open rec opened 1 month ago

rec commented 1 month ago

🐛 Describe the bug

Minimal example:

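import os
# Set before importing torch so that the logging settings take effect
os.environ["TORCH_LOGS"] = "guards"
os.environ["TORCH_TRACE"] = "."

import torch
import numpy as np

# func1 and func2 are identical; two copies are used so that each
# ordering below starts from an empty compilation cache.
@torch._dynamo.optimize()
def func1(a, m):
    return a if m.is_integer() else 2 * a

@torch._dynamo.optimize()
def func2(a, m):
    return a if m.is_integer() else 2 * a

a = torch.ones(3, 3)

print("float then numpy")

func1(a, 2.0)              # first call: Python float
func1(a, np.float32(2.0))  # second call: NumPy scalar with the same value

print("\nnumpy then float")

func2(a, np.float32(2.0))  # first call: NumPy scalar
func2(a, 2.0)              # second call: Python float with the same value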

In the first case, a guard is created which matches 2.0 by value, so it matches numpy.float32(2.0) on the second call.

In the second case, a guard is created which matches numpy.float32(2.0) by value and type, so it does not match 2.0 on the second call.

So the output looks like:

float then numpy

numpy then float
V0717 09:41:57.160523 139957071652672 torch/_dynamo/guards.py:2220] [1/0_1] [__guards] GUARDS:
V0717 09:41:57.160869 139957071652672 torch/_dynamo/guards.py:2186] [1/0_1] [__guards] 
V0717 09:41:57.160869 139957071652672 torch/_dynamo/guards.py:2186] [1/0_1] [__guards] TREE_GUARD_MANAGER:
V0717 09:41:57.160869 139957071652672 torch/_dynamo/guards.py:2186] [1/0_1] [__guards] +- RootGuardManager
V0717 09:41:57.160869 139957071652672 torch/_dynamo/guards.py:2186] [1/0_1] [__guards] | +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:460 in init_ambient_guards
V0717 09:41:57.160869 139957071652672 torch/_dynamo/guards.py:2186] [1/0_1] [__guards] | +- GLOBAL_STATE: ___check_global_state()
V0717 09:41:57.160869 139957071652672 torch/_dynamo/guards.py:2186] [1/0_1] [__guards] | +- GuardManager: source=L['m'], accessed_by=DictGetItemGuardAccessor(m)
V0717 09:41:57.160869 139957071652672 torch/_dynamo/guards.py:2186] [1/0_1] [__guards] | | +- GuardManager: source=___from_numpy(L['m']), accessed_by=PythonLambdaGuardAccessor
V0717 09:41:57.160869 139957071652672 torch/_dynamo/guards.py:2186] [1/0_1] [__guards] | | | +- TENSOR_MATCH: check_tensor(___from_numpy(L['m']), Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[], stride=[])  # return a if m.is_integer() else 2 * a  # code/test/python/guards.py:16 in func2
V0717 09:41:57.160869 139957071652672 torch/_dynamo/guards.py:2186] [1/0_1] [__guards] | | | +- NO_HASATTR: hasattr(___from_numpy(L['m']), '_dynamo_dynamic_indices') == False  # return a if m.is_integer() else 2 * a  # code/test/python/guards.py:16 in func2
V0717 09:41:57.160869 139957071652672 torch/_dynamo/guards.py:2186] [1/0_1] [__guards] 
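The asymmetry described above can be modeled in plain Python. This is a hypothetical simplification, not Dynamo's guard code: `make_float_guard` stands in for the guard built from a Python float (match by value), and `make_numpy_guard` stands in for the guard built from a NumPy scalar (match by value and type), following the description in this issue. The helper names are invented for illustration.

```python
import numpy as np

# Hypothetical stand-ins for the two guards described in this issue.
# These are NOT Dynamo's real guard functions, only a model of their behavior.

def make_float_guard(expected):
    # Guard built from a Python float: compares by value only.
    return lambda m: m == expected

def make_numpy_guard(expected):
    # Guard built from a NumPy scalar: compares by type and value.
    return lambda m: type(m) is type(expected) and m == expected

def guard_passes(first, second, make_guard):
    # Build the guard from the first call's argument, then test the second one.
    guard = make_guard(first)
    return bool(guard(second))

# "float then numpy": np.float32(2.0) == 2.0 is true, so the guard passes.
float_then_numpy = guard_passes(2.0, np.float32(2.0), make_float_guard)
# "numpy then float": type(2.0) is float, not np.float32, so the guard fails.
numpy_then_float = guard_passes(np.float32(2.0), 2.0, make_numpy_guard)

print(float_then_numpy, numpy_then_float)
```

The order of the two calls decides which guard gets built, and only the value-only guard accepts both argument types.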

Versions

PyTorch version: 2.5.0a0+git0424c83
Is debug build: False
CUDA used to build PyTorch: 12.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (conda-forge gcc 12.3.0-7) 12.3.0
Clang version: Could not collect
CMake version: version 3.29.3
Libc version: glibc-2.35

Python version: 3.8.19 | packaged by conda-forge | (default, Mar 20 2024, 12:47:35)  [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-97-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA GeForce RTX 2060
GPU 1: NVIDIA GeForce RTX 2060

Nvidia driver version: 545.23.08
cuDNN version: Probably one of the following:
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_adv.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_cnn.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_engines_precompiled.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_engines_runtime_compiled.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_graph.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_heuristic.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_ops.so.9
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      43 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             64
On-line CPU(s) list:                0-63
Vendor ID:                          AuthenticAMD
Model name:                         AMD Ryzen Threadripper 3970X 32-Core Processor
CPU family:                         23
Model:                              49
Thread(s) per core:                 2
Core(s) per socket:                 32
Socket(s):                          1
Stepping:                           0
Frequency boost:                    enabled
CPU max MHz:                        3700.0000
CPU min MHz:                        2200.0000
BogoMIPS:                           7400.38
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es
Virtualization:                     AMD-V
L1d cache:                          1 MiB (32 instances)
L1i cache:                          1 MiB (32 instances)
L2 cache:                           16 MiB (32 instances)
L3 cache:                           128 MiB (8 instances)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-63
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] bert_pytorch==0.0.1a4
[pip3] clip-anytorch==2.6.0
[pip3] CoCa-pytorch==0.1.0
[pip3] dalle2-pytorch==1.14.2
[pip3] ema-pytorch==0.4.8
[pip3] flake8==6.1.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] functorch==1.14.0a0+b71aa0b
[pip3] mypy==1.10.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.3
[pip3] onnx==1.16.1
[pip3] open-clip-torch==2.24.0
[pip3] optree==0.11.0
[pip3] pytorch-labs-segment-anything-fast==0.2
[pip3] pytorch-triton==3.0.0+dedb7bdf33
[pip3] pytorch-warmup==0.1.1
[pip3] rotary-embedding-torch==0.3.3
[pip3] torch==2.5.0a0+git8a28a9e
[pip3] torch-fidelity==0.3.0
[pip3] torch_geometric==2.4.0
[pip3] torchao==0.2.0
[pip3] torchaudio==2.4.0a0+b829e93
[pip3] torchdata==0.7.1a0+958eeb0
[pip3] torchmetrics==1.4.0.post0
[pip3] torchmultimodal==0.1.0b0
[pip3] torchtext==0.17.0a0+09e2690
[pip3] torchvision==0.20.0a0+bf01bab
[pip3] vector_quantize_pytorch==1.14.8
[conda] bert-pytorch              0.0.1a4                   dev_0    <develop>
[conda] clip-anytorch             2.6.0                    pypi_0    pypi
[conda] coca-pytorch              0.1.0                    pypi_0    pypi
[conda] dalle2-pytorch            1.14.2                   pypi_0    pypi
[conda] ema-pytorch               0.4.8                    pypi_0    pypi
[conda] functorch                 1.14.0a0+b71aa0b          pypi_0    pypi
[conda] libmagma                  2.7.2                h173bb3b_2    conda-forge
[conda] libmagma_sparse           2.7.2                h173bb3b_3    conda-forge
[conda] magma                     2.7.2                h51420fd_3    conda-forge
[conda] mkl                       2024.1.0           ha957f24_693    conda-forge
[conda] mkl-include               2024.1.0           ha957f24_693    conda-forge
[conda] numpy                     1.24.3                   pypi_0    pypi
[conda] open-clip-torch           2.24.0                   pypi_0    pypi
[conda] optree                    0.11.0           py38h7f3f72f_0    conda-forge
[conda] pytorch-labs-segment-anything-fast 0.2                      pypi_0    pypi
[conda] pytorch-triton            3.0.0+dedb7bdf33          pypi_0    pypi
[conda] pytorch-warmup            0.1.1                    pypi_0    pypi
[conda] rotary-embedding-torch    0.3.3                    pypi_0    pypi
[conda] torch                     2.5.0a0+git8a28a9e           dev_0    <develop>
[conda] torch-fidelity            0.3.0                    pypi_0    pypi
[conda] torch-geometric           2.4.0                    pypi_0    pypi
[conda] torchao                   0.2.0                    pypi_0    pypi
[conda] torchaudio                2.4.0a0+b829e93          pypi_0    pypi
[conda] torchdata                 0.7.1a0+958eeb0          pypi_0    pypi
[conda] torchmetrics              1.4.0.post0              pypi_0    pypi
[conda] torchmultimodal           0.1.0b0                  pypi_0    pypi
[conda] torchtext                 0.17.0a0+09e2690          pypi_0    pypi
[conda] torchvision               0.20.0a0+bf01bab          pypi_0    pypi
[conda] vector-quantize-pytorch   1.14.8                   pypi_0    pypi

cc @mruberry @rgommers @ezyang @anijain2305 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames

ezyang commented 1 month ago

cc @lezcano

lezcano commented 1 month ago

This is quite reasonable, really, given that NumPy scalars are treated as tensors inside Dynamo. Also, given how narrow the scope of this issue is, I vote nofix.
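To make the point about NumPy scalars concrete: the TENSOR_MATCH guard in the log checks `___from_numpy(L['m'])` against a `torch.float32` tensor with `size=[]` and `stride=[]`. The sketch below uses `torch.as_tensor` as a stand-in for that conversion (an assumption; Dynamo's internal `___from_numpy` helper may differ) to show that `np.float32(2.0)` becomes a 0-d float32 tensor, the kind of object that guard inspects.

```python
import numpy as np
import torch

m = np.float32(2.0)

# Stand-in for Dynamo's internal conversion of a NumPy scalar to a tensor
# (assumption: torch.as_tensor is used here only for illustration).
t = torch.as_tensor(m)

# A 0-d tensor: its dtype, shape, and stride are what TENSOR_MATCH compares.
print(t.dtype, list(t.shape), t.stride())
```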