🐛 Describe the bug
Calling hasattr on a tensor inside a torch.compile'd function causes a graph break, while getattr does not.
Error logs
Traceback (most recent call last):
  File "/home/tbohutyn/workspace/testy/cpu_var_tracker.py", line 19, in <module>
    main()
  File "/home/tbohutyn/workspace/testy/cpu_var_tracker.py", line 15, in main
    fn(t1)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 500, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2099, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 815, in run
    and self.step()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 778, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 494, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1169, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 679, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 935, in call_function
    return handler(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 700, in <lambda>
    return lambda tx, args, kwargs: obj.call_function(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 935, in call_function
    return handler(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 810, in builtin_dipatch
    rv = handler(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 739, in call_self_handler
    result = self_handler(tx, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 1394, in call_hasattr
    return obj.call_hasattr(tx, name)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/base.py", line 308, in call_hasattr
    unimplemented(f"hasattr {self.__class__.__name__} {name}")
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 190, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: hasattr TensorVariable attr

from user code:
  File "/home/tbohutyn/workspace/testy/cpu_var_tracker.py", line 6, in fn
    if hasattr(x, "attr"):

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
Minified repro
    import torch

    # using fullgraph to raise an exception in place of graph breaks
    @torch.compile(backend='inductor', fullgraph=True)
    def fn(x):
        if hasattr(x, "attr"):
            return x + 1
        else:
            return x - 1

    def main():
        t1 = torch.tensor([6.])
        t1.attr = False
        fn(t1)

    if __name__ == "__main__":
        main()
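
For comparison, here is a minimal sketch of one way to sidestep the unsupported call: evaluate hasattr in eager Python and pass the result into the compiled function as a plain bool. This is only an illustration, not a fix for the underlying Dynamo behavior. The names `call_fn` and `has_attr` are hypothetical and not part of the original repro, and `backend="eager"` is an assumption chosen so the sketch does not depend on a C++ toolchain.

```python
import torch

# Same branch logic as the repro, but the attribute check happens outside
# the traced region. `has_attr` is a hypothetical parameter name.
@torch.compile(backend="eager", fullgraph=True)
def fn(x, has_attr: bool):
    if has_attr:
        return x + 1
    return x - 1

def call_fn(x):
    # hasattr runs in regular Python here, so Dynamo never has to trace it.
    return fn(x, hasattr(x, "attr"))

if __name__ == "__main__":
    t1 = torch.tensor([6.])
    t1.attr = False
    print(call_fn(t1))
```

Because the flag is a Python bool, Dynamo should specialize on its value, so each of the two branches is compiled separately the first time it is hit.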
Versions
Collecting environment information...
PyTorch version: 2.3.1a0+gitbede712
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.6
Libc version: glibc-2.35
Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-107-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Gold 6132 CPU @ 2.60GHz
CPU family: 6
Model: 85
Thread(s) per core: 1
Core(s) per socket: 6
Socket(s): 2
Stepping: 0
BogoMIPS: 5187.81
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 invpcid avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xsaves arat pku ospke md_clear flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: VMware
Virtualization type: full
L1d cache: 384 KiB (12 instances)
L1i cache: 384 KiB (12 instances)
L2 cache: 12 MiB (12 instances)
L3 cache: 38.5 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX flush not necessary, SMT disabled
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS; IBPB conditional; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Syscall hardening, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] pytorch-lightning==2.3.1
[pip3] torch==2.3.1a0+gitbede712
[pip3] torch_tb_profiler==0.4.0
[pip3] torchaudio==2.3.0+952ea74
[pip3] torchdata==0.7.1+5e6f7b7
[pip3] torchmetrics==1.4.0.post0
[pip3] torchtext==0.18.0a0+9bed85d
[pip3] torchvision==0.18.1a0+fe70bc8
[conda] Could not collect
cc @ezyang @anijain2305 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng