**Open** · RaulPPelaez opened this issue 4 months ago
Longer backtrace:

```
  File "/home/ezyang/local/b/pytorch-env/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/ezyang/local/b/pytorch-env/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/utils.py", line 218, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_inductor/compile_fx.py", line 554, in compile_fx_inner
    (not has_incompatible_cudagraph_ops(gm), "incompatible ops"),
  File "/data/users/ezyang/b/pytorch/torch/_inductor/utils.py", line 671, in has_incompatible_cudagraph_ops
    return get_first_incompatible_cudagraph_node(gm) is not None
  File "/data/users/ezyang/b/pytorch/torch/_inductor/utils.py", line 665, in get_first_incompatible_cudagraph_node
    if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val):
  File "/data/users/ezyang/b/pytorch/torch/fx/experimental/symbolic_shapes.py", line 447, in free_unbacked_symbols
    return {s for s in free_symbols(x) if symbol_is_type(s, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT))}
  File "/data/users/ezyang/b/pytorch/torch/fx/experimental/symbolic_shapes.py", line 434, in free_symbols
    first_expr = next(itr)
  File "/data/users/ezyang/b/pytorch/torch/fx/experimental/symbolic_shapes.py", line 426, in _iterate_exprs
    raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: cannot extract sympy expressions from {'in_ptr0': FakeTensor(..., device='cuda:0', size=(4,)), 'in_ptr1': FakeTensor(..., device='cuda:0', size=(4,)), 'out_ptr': FakeTensor(..., device='cuda:0', size=(4,))} <class 'torch.fx.immutable_collections.immutable_dict'>
```
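For context, I believe the assertion reproduces outside Inductor whenever a dict reaches `free_unbacked_symbols`: `_iterate_exprs` handles tensors and symbolic types but has no dict branch, and the Triton kernel node's `meta["val"]` is a dict of pointer names to FakeTensors. A minimal untested sketch of that reading:

```python
import torch
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

t = torch.empty(4)
free_unbacked_symbols(t)               # OK: _iterate_exprs walks the tensor's sizes/strides
free_unbacked_symbols({"in_ptr0": t})  # AssertionError: cannot extract sympy expressions from {...}
```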
Same error here. Normally it should pass `FakeTensor(..., device='cuda:0', size=(4,))` into the iterator instead of the dict. I wrote several different Triton kernels; some trigger this error and some do not. So far, I haven't found any consistent pattern.
I'm also hitting this issue when compiling a function that calls a custom Triton kernel. For me it happens whether `mode` is set or not, i.e. `torch.compile(my_fn, fullgraph=True)` fails. I'm using 2.5.0.dev20240617+cu121. I get

```
AssertionError: cannot extract sympy expressions from {'Y_ptr': FakeTensor(..., device='cuda:0', size=(32, 1, 352), dtype=torch.bfloat16)} <class 'torch.fx.immutable_collections.immutable_dict'>
```

with the same stack trace as https://github.com/pytorch/pytorch/issues/126864#issuecomment-2138413507 above.
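The shape of my repro is roughly the following (a minimal sketch; `copy_kernel` and `my_fn` are hypothetical stand-ins for my actual code, and this assumes a CUDA build with Triton installed):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(Y_ptr, X_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Trivial elementwise copy, just to have a user-defined Triton kernel.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(X_ptr + offsets, mask=mask)
    tl.store(Y_ptr + offsets, x, mask=mask)

def my_fn(x):
    y = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    copy_kernel[grid](y, x, n, BLOCK_SIZE=1024)
    return y

compiled = torch.compile(my_fn, fullgraph=True)
compiled(torch.randn(32, 1, 352, device="cuda", dtype=torch.bfloat16))  # raises the AssertionError above
```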
As a quick hack I added

```python
elif isinstance(val, dict) and len(val) == 1:
    yield from _iterate_exprs(next(iter(val.values())))
```

to `_iterate_exprs()` in `symbolic_shapes.py`. Now it works fine, though of course I'm not sure if this might have unintended consequences.
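A slightly more general variant of the same hack (my own untested sketch) would recurse into every value rather than only single-entry dicts, since the dict in the original report above holds three pointers:

```python
elif isinstance(val, dict):
    # Triton kernel args reach here as an immutable_dict of
    # pointer name -> FakeTensor; recurse into each value.
    for v in val.values():
        yield from _iterate_exprs(v)
```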
### 🐛 Describe the bug

The examples in the triton+compile tutorial here fail if a non-default mode is provided:
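(The tutorial's kernel code is not repeated here. The failing call is along these lines, where `add_fn` is a hypothetical stand-in for any of the tutorial's examples; a non-default mode such as `"reduce-overhead"` is what routes compilation through the cudagraphs check seen in the backtrace above:)

```python
compiled = torch.compile(add_fn)                          # default mode: works
compiled = torch.compile(add_fn, mode="reduce-overhead")  # non-default mode: fails
```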
### Error logs

### Minified repro

No response

### Versions
cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang