
PyTorch 2.0 torch.compile fails with unfold, unsqueeze, permute sequence for negative index #91505

Closed m-lyon closed 1 year ago

m-lyon commented 1 year ago

🐛 Describe the bug

torch.compile fails when given the following specific sequence of operations:

import torch

class Model(torch.nn.Module):

    def forward(self, tensor: torch.Tensor):  # (4, 4)
        tensor = tensor.unfold(0, 2, 1)  # (3, 4, 2)
        tensor = tensor.unsqueeze(1)  # (3, 1, 4, 2)
        tensor = tensor.permute([0, 2, 3, -3])  # (3, 4, 2, 1)

        return tensor

compiled_model = torch.compile(Model())
a_in = torch.randn((4, 4))
a_out = compiled_model(a_in)

Importantly, swapping out tensor = tensor.permute([0, 2, 3, -3]) for the equivalent tensor = tensor.permute([0, 2, 3, 1]) stops the compiler from crashing. Additionally, removing the tensor.unfold and tensor.unsqueeze calls (and adjusting the dimension sizes accordingly) also stops the compiler from crashing.
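For anyone hitting the same crash, here is a minimal workaround sketch based on the observation above: normalize any negative dims to their non-negative equivalents before calling permute. The helper name permute_nonneg is made up for this sketch and is not part of PyTorch.

import torch

def permute_nonneg(tensor: torch.Tensor, dims):
    # Map negative dim indices (e.g. -3 on a 4-d tensor) to their
    # non-negative equivalents so the compiled graph only ever sees
    # non-negative permutation indices.
    ndim = tensor.dim()
    return tensor.permute([d if d >= 0 else d + ndim for d in dims])

class Model(torch.nn.Module):
    def forward(self, tensor: torch.Tensor):  # (4, 4)
        tensor = tensor.unfold(0, 2, 1)       # (3, 4, 2)
        tensor = tensor.unsqueeze(1)          # (3, 1, 4, 2)
        return permute_nonneg(tensor, [0, 2, 3, -3])  # (3, 4, 2, 1)

compiled_model = torch.compile(Model())
print(compiled_model(torch.randn((4, 4))).shape)  # torch.Size([3, 4, 2, 1])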

Below is the full traceback:

/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py:372: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled.Consider setting `torch.set_float32_matmul_precision('high')`
  warnings.warn(
Traceback (most recent call last):
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_inductor/graph.py", line 369, in run_node
    result = super().run_node(n)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/fx/interpreter.py", line 177, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_inductor/graph.py", line 338, in output
    self.graph_outputs = [ir.ExternKernel.realize_input(x) for x in result]
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_inductor/graph.py", line 338, in <listcomp>
    self.graph_outputs = [ir.ExternKernel.realize_input(x) for x in result]
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 2502, in realize_input
    return cls.realize_input(x.data)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 2511, in realize_input
    return cls.convert_to_reinterpret_view(x)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 2459, in convert_to_reinterpret_view
    rw = extract_read_writes(x.make_loader(), x.get_size(), normalize=False)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 1206, in make_loader
    inv = [inv[i] for i in range(len(self.dims))]
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 1206, in <listcomp>
    inv = [inv[i] for i in range(len(self.dims))]
KeyError: '1\n\nWhile executing return (permute,)\nOriginal traceback:\nNone'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 676, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 1032, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/__init__.py", line 1151, in __call__
    return self.compile_fn(model_, inputs_)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 398, in compile_fx
    return aot_autograd(
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/optimizations/training.py", line 78, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2353, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2050, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_tensor_args, aot_config)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1305, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 955, in aot_dispatch_base
    compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 373, in fw_compiler
    return inner_compile(
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 588, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_inductor/debug.py", line 223, in inner
    return fn(*args, **kwargs)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 139, in compile_fx_inner
    graph.run(*example_inputs)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_inductor/graph.py", line 170, in run
    return super().run(*args)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/fx/interpreter.py", line 143, in run
    raise RuntimeError(*e.args) from e
RuntimeError: 1

While executing return (permute,)
Original traceback:
None

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "permute_problem.py", line 15, in <module>
    a_out = compiled_model(a_in)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 83, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 212, in _fn
    return fn(*args, **kwargs)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 333, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 480, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
    return fn(*args, **kwargs)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 339, in _convert_frame_assert
    return _compile(
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 400, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 387, in transform
    tracer.run()
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1684, in run
    super().run()
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 538, in run
    and self.step()
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 501, in step
    getattr(self, inst.opname)(inst)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1750, in RETURN_VALUE
    self.output.compile_subgraph(self)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 529, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 600, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/matt/anaconda3/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 681, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised RuntimeError: 1

While executing return (permute,)
Original traceback:
None

Set torch._dynamo.config.verbose=True for more information

You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

Versions

PyTorch version: 2.0.0.dev20221225+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.5.0-1ubuntu1~22.04) 9.5.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.35

Python version: 3.8.12 (default, Oct 12 2021, 13:49:34)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.15.0-56-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3080
Nvidia driver version: 515.86.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.24.0
[pip3] phd-torch==1.0.0
[pip3] phd-torchscripts==1.0.0
[pip3] pytorch-lightning==1.7.7
[pip3] pytorch-memlab==0.2.4
[pip3] se3-transformer-pytorch==0.8.13
[pip3] torch==2.0.0.dev20221225+cu117
[pip3] torchaudio==2.0.0.dev20221222+cu117
[pip3] torchConvNd==0.2.0
[pip3] torchinfo==1.7.0
[pip3] torchmetrics==0.7.2
[pip3] torchtriton==2.0.0+0d7e753227
[pip3] torchvision==0.15.0.dev20221222+cu117
[conda] blas                      1.0                         mkl  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0            py38h7f8727e_0  
[conda] mkl_fft                   1.3.1            py38hd3c417c_0  
[conda] mkl_random                1.2.2            py38h51133e4_0  
[conda] numpy                     1.19.2                   pypi_0    pypi
[conda] phd-torch                 1.0.0                     dev_0    <develop>
[conda] phd-torchscripts          1.0.0                     dev_0    <develop>
[conda] pytorch-cuda              11.7                 h67b0de4_1    pytorch
[conda] pytorch-lightning         1.7.7                    pypi_0    pypi
[conda] pytorch-memlab            0.2.4                    pypi_0    pypi
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] se3-transformer-pytorch   0.8.13                   pypi_0    pypi
[conda] torch                     2.0.0.dev20221225+cu117          pypi_0    pypi
[conda] torchaudio                2.0.0.dev20221222+cu117          pypi_0    pypi
[conda] torchconvnd               0.2.0                    pypi_0    pypi
[conda] torchinfo                 1.7.0                    pypi_0    pypi
[conda] torchmetrics              0.7.2                    pypi_0    pypi
[conda] torchtriton               2.0.0+0d7e753227          pypi_0    pypi
[conda] torchvision               0.15.0.dev20221222+cu117          pypi_0    pypi

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @mlazos @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire

MDK8888 commented 1 year ago

Hey, I would love to work on this! I was wondering if there are any hints or guides to get started with this. Thanks so much!

ngimel commented 1 year ago

Here https://github.com/pytorch/pytorch/blob/cce577b39154b501705f32ee0392c77eee43820b/torch/_inductor/ir.py#L1179, add dims = cls._map_neg_dims(dims), modify the assert line accordingly, and add the failing example above as a test to test_torchinductor.py.
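For illustration, a standalone sketch (not the actual torch/_inductor/ir.py code) of what mapping negative dims to non-negative dims amounts to:

def map_neg_dims(dims, ndim):
    # Convert any negative dim index into its non-negative equivalent,
    # e.g. with ndim=4: [0, 2, 3, -3] -> [0, 2, 3, 1]
    return [d if d >= 0 else d + ndim for d in dims]

assert map_neg_dims([0, 2, 3, -3], 4) == [0, 2, 3, 1]
# After normalization the dims form a proper permutation of range(ndim),
# which is what the existing assert can then check directly.
assert sorted(map_neg_dims([0, 2, 3, -3], 4)) == list(range(4))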

MDK8888 commented 1 year ago

Ah, I see. Originally I was thinking about fixing line 1206 below, but at most that would only fix what's wrong with tensor.permute(), not tensor.unfold() and tensor.unsqueeze(). Thank you so much for referring me to this, I really appreciate it!

https://github.com/pytorch/pytorch/blob/7ef7c57ae7d85137daeabdb1e2d4b28c1b62ce4b/torch/_inductor/ir.py#L1206
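A rough sketch of the kind of regression test suggested above, comparing eager and compiled outputs for the failing example; it would presumably be adapted to the existing test_torchinductor.py harness rather than used verbatim:

import torch

def test_permute_negative_dim_after_unfold_unsqueeze():
    class Model(torch.nn.Module):
        def forward(self, tensor):
            tensor = tensor.unfold(0, 2, 1)       # (3, 4, 2)
            tensor = tensor.unsqueeze(1)          # (3, 1, 4, 2)
            return tensor.permute([0, 2, 3, -3])  # (3, 4, 2, 1)

    model = Model()
    a_in = torch.randn((4, 4))
    expected = model(a_in)              # eager reference
    actual = torch.compile(model)(a_in)  # currently raises BackendCompilerFailed
    torch.testing.assert_close(actual, expected)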