Open clessig opened 4 months ago
A repro would help us debug, but cc @zou3519 @bdhirsh in case they have some ideas about the error message "RuntimeError: The grad inputs should be same tensor subclass type as forward output".
Here's another failure case with nested_tensor and compile:

0: torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
0: RuntimeError: shape '[192, -1, 16, 96]' is invalid for input of size 1536*s13
0:
0: While executing %reshape : [num_users=1] = call_method[target=reshape](args = (%l__self___proj_heads_v, [192, -1, 16, 96]), kwargs = {})
0: Original traceback:
0: File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-test/ai-obs-experimental-transformer/obslearn/model/attention.py", line 130, in forward
0: vs = self.proj_heads_v( x_kv).reshape(s).transpose( -3, -2)
0:
0: While executing %submod_1 : [num_users=1] = call_module[target=submod_1](args = (%getitem, %zf12), kwargs = {})
"RuntimeError: The grad inputs should be same tensor subclass type as forward output".
We're tentatively hoping to be able to fix this as part of https://github.com/pytorch/pytorch/issues/91469, although it will take a while.
@clessig this error effectively shows up when you're compiling some function f:

@torch.compile
def f(x):
    out = ...
    return out

where out is a tensor subclass (e.g. NestedTensor), but when you call .backward() the corresponding gradient that flows into the backward graph is not a NestedTensor.
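To make the failure mode concrete, here is a small illustrative sketch (not PyTorch's actual implementation; all names here are made up) of the kind of type check AOTAutograd's runtime wrapper performs before running the compiled backward graph:

```python
# Illustrative sketch only -- NOT PyTorch's actual code. AOTAutograd's
# runtime wrapper enforces a rule of this shape: the gradient flowing
# into the compiled backward must have the same tensor-subclass type
# as the corresponding forward output.

def check_grad_subclass(forward_out, grad_in):
    if type(grad_in) is not type(forward_out):
        raise RuntimeError(
            "The grad inputs should be same tensor subclass type "
            "as forward output"
        )

class NestedTensorStandIn:
    """Stand-in for a tensor subclass such as NestedTensor."""

class PlainTensorStandIn:
    """Stand-in for a plain dense torch.Tensor."""

# Matching subclass types: the check passes silently.
check_grad_subclass(NestedTensorStandIn(), NestedTensorStandIn())

# A dense gradient flowing into a NestedTensor output: the check raises.
try:
    check_grad_subclass(NestedTensorStandIn(), PlainTensorStandIn())
except RuntimeError as err:
    print(err)
```

This is why the error surfaces at `.backward()` time rather than during compilation: the mismatch only becomes visible once a concrete gradient arrives at the backward graph.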
So far I've found that the issue is hit relatively rarely and can sometimes be worked around. So a repro would be helpful (same with the other issues).
@bdhirsh Thanks for the suggestion! I will try to develop a workaround then.
I will also try to come up with some simple repro cases over the weekend; unfortunately I'm on a bit of a tight schedule at the moment.
@zou3519 @bdhirsh, here is a repro case for yet another problem that I ran into while trying to construct a repro case for a different problem :) (The code runs without error when the @torch.compile decorator is commented out.)
import torch

@torch.compile
def mha():
    bs = 2
    nt_len = 2
    q = torch.nested.nested_tensor(
        [torch.rand((bs, 8 * 16), dtype=torch.float16, device='cuda') for _ in range(nt_len)],
        layout=torch.jagged)
    k = torch.nested.nested_tensor(
        [torch.rand((bs, 8 * 16), dtype=torch.float16, device='cuda') for _ in range(nt_len)],
        layout=torch.jagged)
    v = torch.nested.nested_tensor(
        [torch.rand((bs, 8 * 16), dtype=torch.float16, device='cuda') for _ in range(nt_len)],
        layout=torch.jagged)

    q = q.reshape([bs, -1, 8, 16]).transpose(2, 1)
    k = k.reshape([bs, -1, 8, 16]).transpose(2, 1)
    v = v.reshape([bs, -1, 8, 16]).transpose(2, 1)

    att = torch.nn.functional.scaled_dot_product_attention
    # with torch.nn.attention.sdpa_kernel( torch.nn.attention.SDPBackend.FLASH_ATTENTION) :
    out = att(q, k, v).transpose(2, 1)
    return out

out = mha()
print(out.shape)
The error I get is:
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-test/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 590, in create_aot_dispatcher_function
fw_metadata = run_functionalized_fw_and_collect_metadata(
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-test/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 570, in inner
dynamic_dims = {
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-test/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 571, in <setcomp>
i for i, s in enumerate(o.shape) if not is_concrete_int(s)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-test/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 228, in is_concrete_int
if isinstance(a.node.expr, sympy.core.numbers.Integer):
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: 'torch._C._SymNode' object has no attribute 'expr'
What I tried to condense into a repro case is another issue with reshape in conceptually the same situation as above (i.e. reshaping into heads and dim_embed after the head projection for MHA). There I get:
3: torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
3: RuntimeError: shape '[192, -1, 16, 256]' is invalid for input of size 4096*s13
3:
3: While executing %reshape : [num_users=1] = call_method[target=reshape](args = (%l__self___proj_heads_v, [192, -1, 16, 256]), kwargs = {})
3: Original traceback:
3: File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-test/ai-obs-experimental-transformer/obslearn/model/attention.py", line 130, in forward
3: vs = self.proj_heads_v( x_kv).reshape(s).transpose( -3, -2)
3:
3:
3: While executing %submod_1 : [num_users=1] = call_module[target=submod_1](args = (%getitem, %zf12), kwargs = {})
3: Original traceback:
3: None
3:
That's an error I only get when I run my model as a batch job with DDP, but it occurs before I wrap the model with torch.nn.parallel.DistributedDataParallel(). I couldn't find any difference in my runtime between the two cases.
torch.compile gives me a significant speedup (30%-40%) when I run interactively, so I would very much like to have it in batch mode as well. Happy to help!
Thanks!
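For background on why the reshape above fails: to infer a -1 dimension, reshape needs the total element count to be exactly divisible by the product of the known target dimensions, and with a symbolic ragged size such as 4096*s13 the compiler cannot prove that divisibility. A minimal pure-Python sketch of the inference rule (the helper is hypothetical, for illustration only):

```python
def infer_reshape_dim(total_size, target_shape):
    """Infer a single -1 entry in target_shape, mimicking reshape's rule:
    total_size must be divisible by the product of the known dimensions."""
    known = 1
    for d in target_shape:
        if d != -1:
            known *= d
    if total_size % known != 0:
        raise RuntimeError(
            f"shape '{list(target_shape)}' is invalid for input of size {total_size}")
    return [total_size // known if d == -1 else d for d in target_shape]

# Concrete sizes work: 192 * 16 * 256 = 786432 divides 786432 * 2.
print(infer_reshape_dim(786432 * 2, [192, -1, 16, 256]))  # [192, 2, 16, 256]

# With a ragged (symbolic) size like 4096*s13, divisibility by
# 192*16*256 cannot be established at compile time, which is what
# the BackendCompilerFailed error reports.
```

In the error above, the input size 4096*s13 would only be divisible by 192*16*256 = 786432 if s13 were a multiple of 192, which the compiler cannot assume for a ragged dimension.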
@clessig the error you're seeing here is related to constructing an NJT in a compiled graph. This is broken right now (see #126472 and related issues).
AttributeError: 'torch._C._SymNode' object has no attribute 'expr'
@soulitzer and I have been very actively working to get a proper fix in for these construction issues (more background and a probably working fix can be found in #130505).
In the meantime, constructing NJTs outside of the compiled function and passing them in should function as a workaround.
🐛 Describe the bug
I use nested_tensor in somewhat more complex code, but compilation fails in different places. Below are the stack traces. I can try to construct some repro cases over the weekend, but I thought I'd post this already in case someone has an idea of what is going wrong based on the stack traces.
My code uses automatic mixed precision, which should be relevant for the first error. Is this supposed to work?
Error logs
Error stack trace 1:
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/train/train.py", line 127, in <module>
train()
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/train/train.py", line 120, in train
trainer.run( cf)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/train/trainer.py", line 338, in run
self.train( epoch)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/train/trainer.py", line 470, in train
self.grad_scaler.scale(loss).backward()
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
torch.autograd.backward(
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 288, in backward
_engine_run_backward(
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/autograd/graph.py", line 799, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/autograd/function.py", line 306, in apply
return user_fn(self, *args)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1820, in backward
raise RuntimeError(
RuntimeError: The grad inputs should be same tensor subclass type as forward output
Error stack trace 2:
Traceback (most recent call last):
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/train/train.py", line 127, in <module>
train()
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/train/train.py", line 120, in train
trainer.run( cf)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/train/trainer.py", line 334, in run
self.validate( -1, num_batches=10)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/train/trainer.py", line 529, in validate
preds = self.ddp_model( source_tokens_cells, source_tokens_lens,
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
return forward_call(*args, **kwargs)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/model/obs_model.py", line 291, in forward
preds_all = self.predict_compiled( tokens, tcs, tcs_lens)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 435, in _fn
return fn(*args, **kwargs)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1121, in __call__
return self._torchdynamo_orig_callable(
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
result = self._inner_convert(
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
return _compile(
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_utils_internal.py", line 85, in wrapper_function
return StrobelightCompileTimeProfiler.profile_compile_time(
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
return func(*args, **kwargs)
File "/usr/local/apps/python3/3.10.10-01/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 234, in time_wrapper
r = func(*args, **kwargs)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
out_code = transform_code_object(code, transform)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1270, in transform_code_object
transformations(instructions, code_options)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
return fn(*args, **kwargs)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
tracer.run()
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2489, in run
super().run()
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
while self.step():
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
self.dispatch_table[inst.opcode](self, inst)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 510, in wrapper
return inner_fn(self, inst)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1480, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 754, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 433, in call_function
return tx.inline_user_function_return(
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 760, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2704, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2820, in inline_call_
tracer.run()
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
while self.step():
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
self.dispatch_table[inst.opcode](self, inst)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 510, in wrapper
return inner_fn(self, inst)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1534, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 754, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 360, in call_function
return super().call_function(tx, args, kwargs)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 302, in call_function
return super().call_function(tx, args, kwargs)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 102, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 760, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2704, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2820, in inline_call_
tracer.run()
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
while self.step():
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
self.dispatch_table[inst.opcode](self, inst)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1598, in LOAD_ATTR
self._load_attr(inst)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1588, in _load_attr
result = BuiltinVariable(getattr).call_function(
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 963, in call_function
return handler(tx, args, kwargs)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 847, in builtin_dipatch
rv = fn(tx, args, kwargs)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 765, in call_self_handler
result = self_handler(tx, *args, **kwargs)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1607, in call_getattr
return obj.var_getattr(tx, name)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 347, in var_getattr
result = handler(tx) if handler is not None else None
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 304, in method_attr_shape
sizes = [variables.ConstantVariable.create(x) for x in self.size]
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 304, in <listcomp>
sizes = [variables.ConstantVariable.create(x) for x in self.size]
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/_dynamo/variables/constant.py", line 39, in create
assert not isinstance(value, disallowed_type), reason
AssertionError: SymInts must use SymNodeVariable. If the underlying value is static, we will create a ConstantVariable and specialize.
from user code:
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/model/obs_model.py", line 379, in predict
tc_tokens = block( tc_tokens, tokens_i)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
return forward_call(*args, **kwargs)
File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell-forecast/ai-obs-experimental-transformer/obslearn/model/attention.py", line 128, in forward
s = [ x_kv.shape[0], -1, self.num_heads, self.dim_head_proj ]
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Minified repro
No response
Versions
Collecting environment information...
PyTorch version: 2.5.0.dev20240710+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Red Hat Enterprise Linux release 8.8 (Ootpa) (x86_64)
GCC version: (GCC) 8.5.0 20210514 (Red Hat 8.5.0-18)
Clang version: 15.0.7 (Red Hat 15.0.7-1.module+el8.8.0+17939+b58878af)
CMake version: version 3.30.0
Libc version: glibc-2.28

Python version: 3.10.10 (main, Feb 9 2023, 14:42:48) [GCC 8.5.0 20210514 (Red Hat 8.5.0-10)] (64-bit runtime)
Python platform: Linux-4.18.0-477.43.1.el8_8.x86_64-x86_64-with-glibc2.28
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-40GB
Nvidia driver version: 550.54.14
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.8.9.7
/usr/lib64/libcudnn_adv_infer.so.8.9.7
/usr/lib64/libcudnn_adv_train.so.8.9.7
/usr/lib64/libcudnn_cnn_infer.so.8.9.7
/usr/lib64/libcudnn_cnn_train.so.8.9.7
/usr/lib64/libcudnn_ops_infer.so.8.9.7
/usr/lib64/libcudnn_ops_train.so.8.9.7
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 256
On-line CPU(s) list: 0-255
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 2
NUMA node(s): 4
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7742 64-Core Processor
Stepping: 0
CPU MHz: 2250.000
CPU max MHz: 2250.0000
CPU min MHz: 1500.0000
BogoMIPS: 4500.27
Virtualization: AMD-V
L1d cache: 32K
L1i cache: 32K
L2 cache: 512K
L3 cache: 16384K
NUMA node0 CPU(s): 0-31,128-159
NUMA node1 CPU(s): 32-63,160-191
NUMA node2 CPU(s): 64-95,192-223
NUMA node3 CPU(s): 96-127,224-255
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es

Versions of relevant libraries:
[pip3] flake8==7.1.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] pytorch-triton==3.0.0+dedb7bdf33
[pip3] torch==2.5.0.dev20240710+cu124
[pip3] triton==2.3.1
[conda] Could not collect
cc @ezyang @anijain2305 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames