pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Compile with non-default mode + triton kernel fails #126864

Status: Open. RaulPPelaez opened this issue 4 months ago.

RaulPPelaez commented 4 months ago

🐛 Describe the bug

The examples in the Triton + torch.compile tutorial fail when a non-default mode is passed to torch.compile:

import torch
from torch.utils._triton import has_triton

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True, mode="reduce-overhead")  # FAILS if mode is provided
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel_autotuned[grid](x, y, output, n_elements)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

Error logs

File "lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 422, in _iterate_exprs
    raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: cannot extract sympy expressions from {'out_ptr': FakeTensor(..., device='cuda:0', size=(4,))} <class 'torch.fx.immutable_collections.immutable_dict'>

Minified repro

No response

Versions

[conda] pytorch                   2.4.0.dev20240508 py3.11_cuda12.4_cudnn8.9.2_0    pytorch-nightly
[conda] pytorch-cuda              12.4                 hc786d27_6    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] torchtriton               3.0.0+45fff310c8           py311    pytorch-nightly

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang

ezyang commented 4 months ago

Longer backtrace:

  File "/home/ezyang/local/b/pytorch-env/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/ezyang/local/b/pytorch-env/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/utils.py", line 218, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_inductor/compile_fx.py", line 554, in compile_fx_inner
    (not has_incompatible_cudagraph_ops(gm), "incompatible ops"),
  File "/data/users/ezyang/b/pytorch/torch/_inductor/utils.py", line 671, in has_incompatible_cudagraph_ops
    return get_first_incompatible_cudagraph_node(gm) is not None
  File "/data/users/ezyang/b/pytorch/torch/_inductor/utils.py", line 665, in get_first_incompatible_cudagraph_node
    if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val):
  File "/data/users/ezyang/b/pytorch/torch/fx/experimental/symbolic_shapes.py", line 447, in free_unbacked_symbols
    return {s for s in free_symbols(x) if symbol_is_type(s, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT))}
  File "/data/users/ezyang/b/pytorch/torch/fx/experimental/symbolic_shapes.py", line 434, in free_symbols
    first_expr = next(itr)
  File "/data/users/ezyang/b/pytorch/torch/fx/experimental/symbolic_shapes.py", line 426, in _iterate_exprs
    raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: cannot extract sympy expressions from {'in_ptr0': FakeTensor(..., device='cuda:0', size=(4,)), 'in_ptr1': FakeTensor(..., device='cuda:0', size=(4,)), 'out_ptr': FakeTensor(..., device='cuda:0', size=(4,))} <class 'torch.fx.immutable_collections.immutable_dict'>
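The frames above suggest that the crash happens in Inductor's CUDA-graph compatibility check (has_incompatible_cudagraph_ops), which walks the graph and passes each node's meta["val"] to free_unbacked_symbols. That is consistent with the failure appearing once mode="reduce-overhead" turns on CUDA graphs. For the user-defined Triton kernel, that "val" is a dict of kernel-argument names to FakeTensors, as the error message shows, and _iterate_exprs rejects dicts. The pure-Python model below is only a sketch under assumptions: ModelTensor, ModelNode, and the helper functions are hypothetical stand-ins, not PyTorch code, and string dimensions starting with "u" play the role of unbacked symbols.

```python
# Hypothetical, simplified model of the call chain in the backtrace:
# get_first_incompatible_cudagraph_node -> free_unbacked_symbols -> _iterate_exprs.
# None of these names are PyTorch APIs; they only mimic the control flow.
from dataclasses import dataclass, field
from typing import Any, Iterator, Optional


@dataclass(frozen=True)
class ModelTensor:
    """Stand-in for a FakeTensor: only its shape matters here."""
    shape: tuple


@dataclass
class ModelNode:
    """Stand-in for an FX node with a meta dict."""
    name: str
    meta: dict = field(default_factory=dict)


def iterate_exprs(val: Any) -> Iterator[Any]:
    # Accepts sizes, tensors, and lists/tuples of them; anything else
    # (including a dict) is rejected, mirroring the AssertionError.
    if isinstance(val, (int, str)):
        yield val
    elif isinstance(val, ModelTensor):
        yield from val.shape
    elif isinstance(val, (list, tuple)):
        for item in val:
            yield from iterate_exprs(item)
    else:
        raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")


def free_unbacked_symbols(val: Any) -> set:
    # Modeling shortcut: string dims starting with "u" count as unbacked symbols.
    return {e for e in iterate_exprs(val) if isinstance(e, str) and e.startswith("u")}


def first_incompatible_node(nodes: list) -> Optional[ModelNode]:
    # Same shape as the check in the backtrace: inspect every node's "val".
    for node in nodes:
        if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val):
            return node
    return None


x = ModelTensor((4,))
ok_graph = [ModelNode("zeros_like", {"val": x})]
# Hypothetical kernel node whose "val" is a dict, like the one in the error message.
triton_graph = ok_graph + [ModelNode("triton_kernel", {"val": {"out_ptr": x}})]

print(first_incompatible_node(ok_graph))  # None
try:
    first_incompatible_node(triton_graph)
except AssertionError as err:
    print(err)  # cannot extract sympy expressions from {'out_ptr': ModelTensor(shape=(4,))} <class 'dict'>
```

In this model, an ordinary tensor-valued node passes the check, while a single dict-valued node is enough to abort the whole walk, which matches the behavior reported in this thread.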

Luke20000429 commented 3 months ago

Same error here. Normally _iterate_exprs should receive a FakeTensor such as FakeTensor(..., device='cuda:0', size=(4,)) rather than a dict containing it. I wrote several different Triton kernels; some trigger this error and some do not. So far I haven't found a consistent pattern.

oscarkey commented 3 months ago

I'm also hitting this issue when compiling a function that calls a custom Triton kernel. For me it happens whether or not mode is set, i.e. even torch.compile(my_fn, fullgraph=True) fails. I'm using 2.5.0.dev20240617+cu121. I get

AssertionError: cannot extract sympy expressions from {'Y_ptr': FakeTensor(..., device='cuda:0', size=(32, 1, 352), dtype=torch.bfloat16)} <class 'torch.fx.immutable_collections.immutable_dict'>

with the same stack trace as https://github.com/pytorch/pytorch/issues/126864#issuecomment-2138413507 above.

As a quick hack I added

elif isinstance(val, dict) and len(val) == 1:
    yield from _iterate_exprs(next(iter(val.values())))

to _iterate_exprs() in symbolic_shapes.py. Now it works fine, though of course I'm not sure whether this might have unintended consequences.
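The hack above only unwraps single-entry dicts (len(val) == 1), but ezyang's backtrace shows a three-entry dict ({'in_ptr0': ..., 'in_ptr1': ..., 'out_ptr': ...}), which it would still reject. A possible generalization is to recurse into every value of the dict. The sketch below is a pure-Python model, not the torch code: the function name iterate_exprs is hypothetical, tensors are modeled as shape tuples, and ints/strings stand in for concrete and symbolic sizes. Whether this is the right fix, as opposed to not storing a dict in meta["val"] or skipping such nodes in the CUDA-graph check, is not settled in this thread.

```python
# Hypothetical generalization of the quick hack: visit every value of a dict
# instead of only single-entry dicts. Pure-Python model, not PyTorch code.
from typing import Any, Iterator


def iterate_exprs(val: Any) -> Iterator[Any]:
    if isinstance(val, (int, str)):
        # A concrete size (int) or a symbolic size (str) in this model.
        yield val
    elif isinstance(val, (list, tuple)):
        # A tensor shape or a sequence of values.
        for item in val:
            yield from iterate_exprs(item)
    elif isinstance(val, dict):
        # e.g. {'in_ptr0': t0, 'in_ptr1': t1, 'out_ptr': t2}: handle every entry,
        # so dicts of any size work, not just len(val) == 1.
        for item in val.values():
            yield from iterate_exprs(item)
    else:
        raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")


vals = {"in_ptr0": (4,), "in_ptr1": (4,), "out_ptr": ("s0",)}
print(list(iterate_exprs(vals)))  # [4, 4, 's0']
```

Recursing through the values keeps the existing behavior for tensors and sequences and lets multi-entry dicts contribute all of their expressions, instead of failing on the first kernel that writes or reads more than one pointer.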