pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

torch.compile fails for complex nested_tensor code #130825

Open clessig opened 1 month ago

clessig commented 1 month ago

🐛 Describe the bug

I use nested_tensor in somewhat more complex code, but compilation fails in different places. Below are the stack traces. I can try to construct some repro cases at the weekend, but I thought I would post this already in case someone has an idea of what is going wrong based on the stack traces.

My code uses automatic mixed precision, which should be relevant for the first error. Is this supposed to work?

Error logs

Error stack trace 1:

```
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/train/train.py", line 127, in <module>
    train()
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/train/train.py", line 120, in train
    trainer.run( cf)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/train/trainer.py", line 338, in run
    self.train( epoch)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/train/trainer.py", line 470, in train
    self.grad_scaler.scale(loss).backward()
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 288, in backward
    _engine_run_backward(
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/autograd/graph.py", line 799, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/autograd/function.py", line 306, in apply
    return user_fn(self, *args)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1820, in backward
    raise RuntimeError(
RuntimeError: The grad inputs should be same tensor subclass type as forward output
```

```
/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py(1820)backward()
-> raise RuntimeError(
```


Error stack trace 2:

```
Traceback (most recent call last):
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/train/train.py", line 127, in <module>
    train()
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/train/train.py", line 120, in train
    trainer.run( cf)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/train/trainer.py", line 334, in run
    self.validate( -1, num_batches=10)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/train/trainer.py", line 529, in validate
    preds = self.ddp_model( source_tokens_cells, source_tokens_lens,
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/model/obs_model.py", line 291, in forward
    preds_all = self.predict_compiled( tokens, tcs, tcs_lens)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 435, in _fn
    return fn(*args, **kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1121, in __call__
    return self._torchdynamo_orig_callable(
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
    result = self._inner_convert(
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
    return _compile(
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_utils_internal.py", line 85, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "/usr/local/apps/python3/3.10.10-01/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 234, in time_wrapper
    r = func(*args, **kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1270, in transform_code_object
    transformations(instructions, code_options)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    return fn(*args, **kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
    tracer.run()
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2489, in run
    super().run()
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
    while self.step():
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 510, in wrapper
    return inner_fn(self, inst)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1480, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 754, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 433, in call_function
    return tx.inline_user_function_return(
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 760, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2704, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2820, in inline_call_
    tracer.run()
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
    while self.step():
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 510, in wrapper
    return inner_fn(self, inst)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1534, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 754, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 360, in call_function
    return super().call_function(tx, args, kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 302, in call_function
    return super().call_function(tx, args, kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 102, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 760, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2704, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2820, in inline_call_
    tracer.run()
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
    while self.step():
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1598, in LOAD_ATTR
    self._load_attr(inst)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1588, in _load_attr
    result = BuiltinVariable(getattr).call_function(
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 963, in call_function
    return handler(tx, args, kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 847, in builtin_dipatch
    rv = fn(tx, args, kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 765, in call_self_handler
    result = self_handler(tx, *args, **kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1607, in call_getattr
    return obj.var_getattr(tx, name)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 347, in var_getattr
    result = handler(tx) if handler is not None else None
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 304, in method_attr_shape
    sizes = [variables.ConstantVariable.create(x) for x in self.size]
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 304, in <listcomp>
    sizes = [variables.ConstantVariable.create(x) for x in self.size]
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/constant.py", line 39, in create
    assert not isinstance(value, disallowed_type), reason
AssertionError: SymInts must use SymNodeVariable. If the underlying value is static, we will create a ConstantVariable and specialize.
```

from user code:

```
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/model/obs_model.py", line 379, in predict
    tc_tokens = block( tc_tokens, tokens_i)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/model/attention.py", line 128, in forward
    s = [ x_kv.shape[0], -1, self.num_heads, self.dim_head_proj ]
```

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:

```python
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```

```
/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/constant.py(39)create()
-> assert not isinstance(value, disallowed_type), reason
```

Minified repro

No response

Versions

```
Collecting environment information...
PyTorch version: 2.5.0.dev20240710+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Red Hat Enterprise Linux release 8.8 (Ootpa) (x86_64)
GCC version: (GCC) 8.5.0 20210514 (Red Hat 8.5.0-18)
Clang version: 15.0.7 (Red Hat 15.0.7-1.module+el8.8.0+17939+b58878af)
CMake version: version 3.30.0
Libc version: glibc-2.28

Python version: 3.10.10 (main, Feb 9 2023, 14:42:48) [GCC 8.5.0 20210514 (Red Hat 8.5.0-10)] (64-bit runtime)
Python platform: Linux-4.18.0-477.43.1.el8_8.x86_64-x86_64-with-glibc2.28
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-40GB
Nvidia driver version: 550.54.14
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.8.9.7
/usr/lib64/libcudnn_adv_infer.so.8.9.7
/usr/lib64/libcudnn_adv_train.so.8.9.7
/usr/lib64/libcudnn_cnn_infer.so.8.9.7
/usr/lib64/libcudnn_cnn_train.so.8.9.7
/usr/lib64/libcudnn_ops_infer.so.8.9.7
/usr/lib64/libcudnn_ops_train.so.8.9.7
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 256
On-line CPU(s) list: 0-255
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 2
NUMA node(s): 4
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7742 64-Core Processor
Stepping: 0
CPU MHz: 2250.000
CPU max MHz: 2250.0000
CPU min MHz: 1500.0000
BogoMIPS: 4500.27
Virtualization: AMD-V
L1d cache: 32K
L1i cache: 32K
L2 cache: 512K
L3 cache: 16384K
NUMA node0 CPU(s): 0-31,128-159
NUMA node1 CPU(s): 32-63,160-191
NUMA node2 CPU(s): 64-95,192-223
NUMA node3 CPU(s): 96-127,224-255
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es

Versions of relevant libraries:
[pip3] flake8==7.1.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] pytorch-triton==3.0.0+dedb7bdf33
[pip3] torch==2.5.0.dev20240710+cu124
[pip3] triton==2.3.1
[conda] Could not collect
```

cc @ezyang @anijain2305 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames

desertfire commented 1 month ago

A repro would help us debug, but cc @zou3519 @bdhirsh in case they have some ideas about the error message "RuntimeError: The grad inputs should be same tensor subclass type as forward output".

clessig commented 1 month ago

Here's another failure case with nested_tensor and compile:

```
0: torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
0: RuntimeError: shape '[192, -1, 16, 96]' is invalid for input of size 1536*s13
0:
0: While executing %reshape : [num_users=1] = call_method[target=reshape](args = (%l__self___proj_heads_v, [192, -1, 16, 96]), kwargs = {})
0: Original traceback:
0:   File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-test/ai-obs-experimental-transformer/obslearn/model/attention.py", line 130, in forward
0:     vs = self.proj_heads_v( x_kv).reshape(s).transpose( -3, -2)
0:
0:
0: While executing %submod_1 : [num_users=1] = call_module[target=submod_1](args = (%getitem, %zf12), kwargs = {})
```

bdhirsh commented 1 month ago

"RuntimeError: The grad inputs should be same tensor subclass type as forward output".

We're tentatively hoping to be able to fix this as part of https://github.com/pytorch/pytorch/issues/91469, although it will take a while.

@clessig this error effectively shows up when you're compiling some function f:

```python
@torch.compile
def f(x):
    out = ...
    return out
```

Where `out` is a tensor subclass (e.g. `NestedTensor`), but when you call `.backward()`, the corresponding gradient that flows into the backward graph is not a `NestedTensor`.

So far I've found that this issue is hit relatively rarely and can sometimes be worked around. So a repro would be helpful (same for the other issues).

clessig commented 1 month ago

@bdhirsh Thanks for the suggestion! I will try to develop a workaround then.

I will also try to come up with some simple repro cases at the weekend; unfortunately I am on a somewhat tight schedule at the moment.

clessig commented 1 month ago

@zou3519 @bdhirsh, here is a repro case for yet another problem that I ran into while trying to construct a repro for the original one :) (The code runs without error when the `@torch.compile` line is commented out.)

```python
import torch

@torch.compile
def mha():
    bs = 2
    nt_len = 2

    q = torch.nested.nested_tensor(
        [torch.rand((bs, 8 * 16), dtype=torch.float16, device='cuda') for _ in range(nt_len)],
        layout=torch.jagged)
    k = torch.nested.nested_tensor(
        [torch.rand((bs, 8 * 16), dtype=torch.float16, device='cuda') for _ in range(nt_len)],
        layout=torch.jagged)
    v = torch.nested.nested_tensor(
        [torch.rand((bs, 8 * 16), dtype=torch.float16, device='cuda') for _ in range(nt_len)],
        layout=torch.jagged)

    q = q.reshape([bs, -1, 8, 16]).transpose(2, 1)
    k = k.reshape([bs, -1, 8, 16]).transpose(2, 1)
    v = v.reshape([bs, -1, 8, 16]).transpose(2, 1)

    att = torch.nn.functional.scaled_dot_product_attention
    # with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
    out = att(q, k, v).transpose(2, 1)

    return out

out = mha()
print(out.shape)
```

The error I get is:

```
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-test/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 590, in create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-test/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 570, in inner
    dynamic_dims = {
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-test/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 571, in <setcomp>
    i for i, s in enumerate(o.shape) if not is_concrete_int(s)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-test/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 228, in is_concrete_int
    if isinstance(a.node.expr, sympy.core.numbers.Integer):
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: 'torch._C._SymNode' object has no attribute 'expr'
```

What I tried to condense into a repro case is another issue with reshape in conceptually the same situation as above (i.e. the reshape into heads and dim_embed after the head projection for MHA). There I get:

```
3: torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
3: RuntimeError: shape '[192, -1, 16, 256]' is invalid for input of size 4096*s13
3:
3: While executing %reshape : [num_users=1] = call_method[target=reshape](args = (%l__self___proj_heads_v, [192, -1, 16, 256]), kwargs = {})
3: Original traceback:
3:   File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-test/ai-obs-experimental-transformer/obslearn/model/attention.py", line 130, in forward
3:     vs = self.proj_heads_v( x_kv).reshape(s).transpose( -3, -2)
3:
3:
3: While executing %submod_1 : [num_users=1] = call_module[target=submod_1](args = (%getitem, %zf12), kwargs = {})
3: Original traceback:
3: None
```

That's an error I only get when I run my model as a batch job with DDP, but before the model is wrapped with torch.nn.parallel.DistributedDataParallel(). I couldn't find any difference in the runtime environment between the two cases.

torch.compile gives me a significant speedup (30%-40%) when I run interactively, so I would very much like to have it in batch mode as well. Happy to help!

Thanks!
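For comparison, the reshape-into-heads pattern from attention.py is unproblematic with plain dense tensors; the failures above only appear once the leading dimension becomes a symbolic size under compile. A dense sketch with hypothetical sizes (`num_heads`/`dim_head` chosen to match the `[192, -1, 16, 96]` shape in the earlier error):

```python
import torch

# Dense-tensor sketch of the multi-head reshape from attention.py.
# Sizes are hypothetical; eagerly, reshape with a concrete leading dim
# works fine.  Under torch.compile that leading dim can become a SymInt,
# which is what the "invalid for input of size 1536*s13" error reflects.
num_heads, dim_head = 16, 96
x_kv = torch.rand(192, 7, num_heads * dim_head)  # (batch, seq, embed)
proj_heads_v = torch.nn.Linear(num_heads * dim_head, num_heads * dim_head)

s = [x_kv.shape[0], -1, num_heads, dim_head]
vs = proj_heads_v(x_kv).reshape(s).transpose(-3, -2)
assert vs.shape == (192, num_heads, 7, dim_head)
```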

jbschlosser commented 1 month ago

@clessig the error you're seeing here is related to constructing an NJT in a compiled graph. This is broken right now (see #126472 and related issues).

AttributeError: 'torch._C._SymNode' object has no attribute 'expr'

@soulitzer and I have been very actively working to get a proper fix in for these construction issues (more background and a probably working fix can be found in #130505).

In the meantime, constructing NJTs outside of the compiled function and passing them in should work as a workaround.