nod-ai / SHARK-ModelDev

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

[Bug] Panic for the pytorch module on the shark-turbine 0.9.3 and 0.9.4 #488

Open Peefy opened 7 months ago

Peefy commented 7 months ago

With shark-turbine 0.9.3 and 0.9.4, the following code panics:

import torch

print("\nInstalled PyTorch, version:", torch.__version__)

torch.manual_seed(0)

class LinearModule(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))
        self.bias = torch.nn.Parameter(torch.randn(out_features))

    def forward(self, input):
        return (input @ self.weight) + self.bias

linear_module = LinearModule(4, 3)
opt_linear_module = (
    torch.compile(linear_module, backend="turbine_cpu")
)
print("Compiled module using Turbine. New module type is", type(opt_linear_module))
args = torch.randn(4)
turbine_output = opt_linear_module(args)

print("Weight:", linear_module.weight)
print("Bias:", linear_module.bias)
print("Args:", args)
print("Output:", turbine_output)

However, shark-turbine 0.9.2 works fine.

ScottTodd commented 7 months ago

Can you clarify what you mean by "it will panic"? Does your Python interpreter crash? Is there an error message? (I usually associate "panic" with "kernel panic", which would be very unexpected here)

Peefy commented 7 months ago

Sorry. I'm using Python 3.11.8, PyTorch 2.2.1, and IREE 20240228.815. The error message is as follows:

Installed PyTorch, version: 2.2.1
Traceback (most recent call last):
  File "/Users/lingzhi/_Code/KCLOpenSource/kcl/a.py", line 20, in <module>
    torch.compile(linear_module, backend="turbine_cpu")
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/__init__.py", line 1824, in compile
    backend = _TorchCompileWrapper(backend, mode, options, dynamic)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/__init__.py", line 1692, in __init__
    self.compiler_fn = lookup_backend(backend)
                       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/_dynamo/backends/registry.py", line 58, in lookup_backend
    _lazy_import_entry_point(compiler_fn)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/_dynamo/backends/registry.py", line 110, in _lazy_import_entry_point
    compiler_fn = backend_eps[backend_name].load()
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/metadata/__init__.py", line 202, in load
    module = import_module(match.group('module'))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/shark_turbine/dynamo/backends/cpu.py", line 40, in <module>
    from ..passes import turbine_cpu_pass_pipeline
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/shark_turbine/dynamo/passes.py", line 56, in <module>
    @register_decomposition(torch.ops.aten._scaled_dot_product_flash_attention.default)
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/_decomp/__init__.py", line 185, in decomposition_decorator
    pytree.tree_map_(register, aten_op)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/_pytree.py", line 607, in tree_map_
    deque(map(func, flat_args), maxlen=0)  # consume and exhaust the iterable
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/_decomp/__init__.py", line 182, in register
    _add_op_to_registry(registry, op, fn)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/_decomp/__init__.py", line 55, in _add_op_to_registry
    raise RuntimeError(f"duplicate registrations for {op_overload}")
RuntimeError: duplicate registrations for aten._scaled_dot_product_flash_attention.default
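
The last frames show where the failure originates: torch/_decomp keeps a table of decompositions keyed by op overload, and _add_op_to_registry raises when the same overload is registered twice. Because the decorator in shark_turbine/dynamo/passes.py runs at import time, merely loading the turbine_cpu backend through torch.compile is enough to trigger it. The following is a simplified, pure-Python model of that behavior for illustration only; DecompositionRegistry and the two decorated functions are hypothetical stand-ins, not PyTorch's or shark-turbine's code.

```python
# Simplified, hypothetical model of a decomposition registry keyed by op name.
# It mirrors the behavior in the traceback (a second registration for the same
# op raises RuntimeError); it is not PyTorch's actual implementation.
from typing import Callable, Dict

DecompFn = Callable[..., object]


class DecompositionRegistry:
    def __init__(self) -> None:
        self._table: Dict[str, DecompFn] = {}

    def register(self, op_name: str) -> Callable[[DecompFn], DecompFn]:
        """Decorator that adds fn to the table and refuses duplicates."""
        def decorator(fn: DecompFn) -> DecompFn:
            if op_name in self._table:
                raise RuntimeError(f"duplicate registrations for {op_name}")
            self._table[op_name] = fn
            return fn
        return decorator

    def __contains__(self, op_name: str) -> bool:
        return op_name in self._table


registry = DecompositionRegistry()
OP = "aten._scaled_dot_product_flash_attention.default"


# First registration: stands in for the decomposition PyTorch itself ships.
@registry.register(OP)
def upstream_decomp(*args):
    return "upstream"


# Second registration: stands in for the extra decomposition added by the
# backend. Because decorators run at import time, this raises immediately.
try:
    @registry.register(OP)
    def backend_decomp(*args):
        return "backend"
except RuntimeError as err:
    print(err)  # duplicate registrations for aten._scaled_dot_product_flash_attention.default
```

The first entry stays in place and the second decorator never completes, which is why the error surfaces as an import failure rather than during compilation.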

stellaraccident commented 7 months ago

@aviator19941 Here is another example. Can you please verify that this does not happen at head? Then I can push a release tomorrow.

bailuan commented 6 months ago

I encounter the same problem with torch==2.2.0a0+git6c8c5ad.

Peefy commented 6 months ago

I've tried shark-turbine==0.9.6 and it works fine. Has the issue been resolved?

bailuan commented 6 months ago

shark-turbine version: ad64c3ff31c1c2d8571d75984403326039636602 (v0.9.6)
PyTorch version: 8ac9b20d4b090c213799e81acf48a55ea8d437d6 (v2.2.0)

Command: python sd_test.py -k testExportVaeModelDecode

The error still occurs.

Traceback (most recent call last):
  File "/workspace/bailuan/offcial_sharkturbine/SHARK-Turbine/models/turbine_models/tests/sd_test.py", line 9, in <module>
    from turbine_models.custom_models.sd_inference import (
  File "/workspace/bailuan/envs/sd_shark_2/lib/python3.10/site-packages/turbine_models-0.9.6-py3.10.egg/turbine_models/custom_models/sd_inference/clip.py", line 14, in <module>
  File "/workspace/bailuan/envs/sd_shark_2/lib/python3.10/site-packages/shark_turbine-0.9.6-py3.10.egg/shark_turbine/aot/__init__.py", line 7, in <module>
    from .compiled_module import CompiledModule
  File "/workspace/bailuan/envs/sd_shark_2/lib/python3.10/site-packages/shark_turbine-0.9.6-py3.10.egg/shark_turbine/aot/compiled_module.py", line 18, in <module>
    from . import builtins
  File "/workspace/bailuan/envs/sd_shark_2/lib/python3.10/site-packages/shark_turbine-0.9.6-py3.10.egg/shark_turbine/aot/builtins/__init__.py", line 8, in <module>
    from .jittable import jittable
  File "/workspace/bailuan/envs/sd_shark_2/lib/python3.10/site-packages/shark_turbine-0.9.6-py3.10.egg/shark_turbine/aot/builtins/jittable.py", line 38, in <module>
    from ...dynamo.passes import (
  File "/workspace/bailuan/envs/sd_shark_2/lib/python3.10/site-packages/shark_turbine-0.9.6-py3.10.egg/shark_turbine/dynamo/passes.py", line 4, in <module>
    from shark_turbine.dynamo import utils
  File "/workspace/bailuan/envs/sd_shark_2/lib/python3.10/site-packages/shark_turbine-0.9.6-py3.10.egg/shark_turbine/dynamo/utils.py", line 17, in <module>
    def scaled_dot_product_flash_attention(
  File "/workspace/bailuan/20231221_pytorch_install/pytorch/torch/_decomp/__init__.py", line 185, in decomposition_decorator
    pytree.tree_map_(register, aten_op)
  File "/workspace/bailuan/20231221_pytorch_install/pytorch/torch/utils/_pytree.py", line 607, in tree_map_
    deque(map(func, flat_args), maxlen=0)  # consume and exhaust the iterable
  File "/workspace/bailuan/20231221_pytorch_install/pytorch/torch/_decomp/__init__.py", line 182, in register
    _add_op_to_registry(registry, op, fn)
  File "/workspace/bailuan/20231221_pytorch_install/pytorch/torch/_decomp/__init__.py", line 55, in _add_op_to_registry
    raise RuntimeError(f"duplicate registrations for {op_overload}")
RuntimeError: duplicate registrations for aten._scaled_dot_product_flash_attention.default

bailuan commented 6 months ago

PyTorch installed from pip works fine, but PyTorch built from source fails.

stellaraccident commented 6 months ago

I really don't understand how this is continuing to happen.

In the future, we are definitely not going to hack decompositions in like that again.

Unless something is lying about versions, I don't see how the problematic path is being triggered: https://github.com/nod-ai/SHARK-Turbine/blob/6e3adb39ffbad5df74a12b3732cf852a9454aaf4/core/shark_turbine/dynamo/utils.py#L13
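
Besides gating on the torch version, another way to avoid the duplicate is to register the extra decomposition only when the op has no entry yet. The following is a minimal sketch of that idea, assuming a plain dict as a stand-in for the decomposition table. register_if_absent is a hypothetical helper, not a shark-turbine or PyTorch API, and whether PyTorch's own table can be probed this way is not established in this thread.

```python
# Hypothetical helper: register a decomposition only when no entry exists yet.
# A plain dict stands in for the real registry; this is not a PyTorch API.
from typing import Callable, Dict

DecompFn = Callable[..., object]


def register_if_absent(
    table: Dict[str, DecompFn], op_name: str
) -> Callable[[DecompFn], DecompFn]:
    """Decorator that adds fn under op_name unless an entry is already present."""
    def decorator(fn: DecompFn) -> DecompFn:
        # setdefault leaves an existing (e.g. upstream) entry untouched.
        table.setdefault(op_name, fn)
        return fn
    return decorator


OP = "aten._scaled_dot_product_flash_attention.default"

# Simulated newer torch: an upstream decomposition is already registered.
newer_torch: Dict[str, DecompFn] = {OP: lambda *args: "upstream"}
# Simulated older torch: nothing registered for this op yet.
older_torch: Dict[str, DecompFn] = {}


@register_if_absent(newer_torch, OP)
@register_if_absent(older_torch, OP)
def fallback_decomp(*args):
    return "fallback"


print(newer_torch[OP]())  # upstream
print(older_torch[OP]())  # fallback
```

The trade-off is that the fallback is silently skipped whenever an upstream entry exists. That suits a compatibility shim, but it can also hide behavioral differences between the two decompositions.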

bailuan commented 6 months ago

@stellaraccident Have you tested it with a source-compiled torch>2.1.0 rather than a pip install? Does it work?

stellaraccident commented 6 months ago

I have not had to use a source-compiled PyTorch for a long time, so no.

bailuan commented 6 months ago

Oh well, there might indeed be an issue here.

stellaraccident commented 6 months ago

I could believe that something is wonky with torch.__version__ for source builds. If you have one handy, I wouldn't mind some poking.

Sorry, I misread the comment above in my haste and thought you were talking about a Turbine install from source.
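
To make the torch.__version__ theory concrete: a pip wheel reports a plain release such as 2.2.1, while the source build reported above is 2.2.0a0+git6c8c5ad, a pre-release with a local label. The sketch below compares two hypothetical gates of the form "take the legacy decomposition path only when torch is older than 2.2.0". It does not reproduce the actual condition in shark_turbine/dynamo/utils.py. Under PEP 440 ordering, 2.2.0a0 sorts before 2.2.0, so a strict check treats this source build as older, whereas a check on the release segment alone does not. The parser is a deliberately small stdlib approximation of PEP 440, not a full implementation.

```python
# Hypothetical version gates showing how a PyTorch source build can be
# misclassified; neither function is shark-turbine's actual check.
import re
from typing import Tuple


def parse_version(v: str) -> Tuple[Tuple[int, ...], bool]:
    """Return (release segment, is_prerelease), e.g. '2.2.0a0+git6c8c5ad' -> ((2, 2, 0), True)."""
    public = v.split("+", 1)[0]  # drop the local label such as '+git6c8c5ad'
    m = re.match(r"^(\d+(?:\.\d+)*)(.*)$", public)
    if m is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    release = tuple(int(part) for part in m.group(1).split("."))
    # a/b/rc pre-releases and .dev releases sort before the final release.
    is_pre = re.match(r"^\.?(a|b|rc|dev)", m.group(2)) is not None
    return release, is_pre


def strict_older_than(v: str, target: Tuple[int, ...]) -> bool:
    """PEP 440-style check: a pre-release of the target counts as older."""
    release, is_pre = parse_version(v)
    width = max(len(release), len(target))
    rel = release + (0,) * (width - len(release))
    tgt = target + (0,) * (width - len(target))
    if rel != tgt:
        return rel < tgt
    return is_pre  # same release numbers: 2.2.0a0 < 2.2.0, but 2.2.0 is not < 2.2.0


def release_older_than(v: str, target: Tuple[int, ...]) -> bool:
    """Lenient check: compare only (major, minor), ignoring pre-release and local tags."""
    release, _ = parse_version(v)
    return release[:2] < target[:2]


for version in ["2.1.2", "2.2.1", "2.2.0a0+git6c8c5ad"]:
    print(version, strict_older_than(version, (2, 2, 0)), release_older_than(version, (2, 2, 0)))
# The source build is the only case where the two gates disagree:
# strict_older_than -> True, release_older_than -> False.
```

If a gate like the strict one were in place, a source build that already contains PyTorch's own decomposition would still take the legacy path and hit the duplicate registration, which would fit the "pip works, source fails" pattern reported here. This is a hypothesis, not a confirmed diagnosis.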

bailuan commented 6 months ago

Actually, I have a question about installing Turbine from source. I installed the shark-turbine, turbine-models, and turbine-serving packages by running python setup.py install in each of those three folders, and that works fine. But while debugging, it is very annoying that I can't add breakpoints in the source code. So why do Turbine users have to install the packages from source? I would rather import the packages directly from the source tree than use installed copies.

stellaraccident commented 6 months ago

You can add it to your PYTHONPATH directly if you want, or run pip install -e core (-e is for editable). We just presented one way to get started, but all of the usual ways to use a Python package apply. There are, sadly, a lot of ways to use such things, and a lot of people have personal preferences.
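
As an illustration of the PYTHONPATH route, the sketch below prepends a source directory to sys.path at runtime, which has a similar effect to setting PYTHONPATH before starting Python, and then checks the module's __file__ to confirm which copy was actually imported. That check is a simple way to make sure breakpoints land in the files you are editing rather than in an installed .egg. The package name demo_pkg is a hypothetical stand-in; for this repository you would point at the directory that contains the shark_turbine package (presumably the core folder, judging by pip install -e core).

```python
# Sketch: import a package straight from a source tree instead of site-packages.
# 'demo_pkg' is a hypothetical stand-in for a real package directory.
import importlib
import sys
import tempfile
from pathlib import Path

# Create a throwaway source tree: <tmp>/demo_pkg/__init__.py
src_root = Path(tempfile.mkdtemp())
pkg_dir = src_root / "demo_pkg"
pkg_dir.mkdir()
(pkg_dir / "__init__.py").write_text("VALUE = 42\n")

# Prepending to sys.path puts this directory ahead of site-packages,
# similar to launching Python with PYTHONPATH=<src_root>.
sys.path.insert(0, str(src_root))
importlib.invalidate_caches()  # make sure the import system sees the new files

demo_pkg = importlib.import_module("demo_pkg")

# __file__ shows which copy was imported; edits and breakpoints there take effect.
print(demo_pkg.__file__)
print(demo_pkg.VALUE)  # 42
```

Note that if an installed copy of the same package was already imported earlier in the process, Python reuses the cached module from sys.modules, so the path change needs to happen before the first import.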