intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

Support Triton 3.0 API alignment with PyTorch team #373

Closed: tdeng5 closed this issue 8 months ago

Stonepia commented 8 months ago

Edit: fixed in https://github.com/intel/intel-xpu-backend-for-triton/pull/374

Triton currently has a circular import error.

This is because Triton imports ipex in driver.py; ipex then imports triton_heuristics.py, which in turn does `from triton import Config`. Hence the circular import.

IMHO, Triton should not import intel_extension_for_pytorch at such an early stage. The import direction should be reversed: ipex should import Triton, rather than Triton importing ipex.
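The failure mode can be reproduced in miniature with two modules that import each other at top level, and fixed by deferring one of the imports into a function body. The module names `pkg_a`/`pkg_b` here are hypothetical stand-ins for triton and ipex; this is a sketch of the pattern, not the actual code:

```python
import importlib
import os
import sys
import tempfile

def try_import(sources):
    """Write the given modules to a temp dir and try importing pkg_a."""
    with tempfile.TemporaryDirectory() as d:
        for name, code in sources.items():
            with open(os.path.join(d, name), "w") as f:
                f.write(code)
        sys.path.insert(0, d)
        try:
            for mod in ("pkg_a", "pkg_b"):
                sys.modules.pop(mod, None)  # fresh import each run
            importlib.import_module("pkg_a")
            return "ok"
        except ImportError as e:
            return f"ImportError: {e}"
        finally:
            sys.path.remove(d)

# Eager top-level imports in both directions: pkg_a is only partially
# initialized when pkg_b asks it for Config, so the import fails.
eager = {
    "pkg_a.py": "import pkg_b\nConfig = 'config'\n",
    "pkg_b.py": "from pkg_a import Config\n",
}

# Deferring the back-import into a function body breaks the cycle:
# by the time get_config() runs, pkg_a is fully initialized.
lazy = {
    "pkg_a.py": "import pkg_b\nConfig = 'config'\n",
    "pkg_b.py": "def get_config():\n    from pkg_a import Config\n    return Config\n",
}

print(try_import(eager))  # ImportError: cannot import name 'Config' ...
print(try_import(lazy))   # ok
```

The eager variant fails with the same "partially initialized module" message as the traceback below.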

>>> import triton
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/user/user/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/python/triton/__init__.py", line 8, in <module>
    from .runtime import (
  File "/home/user/user/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/python/triton/runtime/__init__.py", line 1, in <module>
    from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune, heuristics)
  File "/home/user/user/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/python/triton/runtime/autotuner.py", line 7, in <module>
    from ..testing import do_bench
  File "/home/user/user/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/python/triton/testing.py", line 7, in <module>
    from . import language as tl
  File "/home/user/user/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/python/triton/language/__init__.py", line 6, in <module>
    from .standard import (
  File "/home/user/user/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/python/triton/language/standard.py", line 3, in <module>
    from ..runtime.jit import jit
  File "/home/user/user/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/python/triton/runtime/jit.py", line 10, in <module>
    from ..runtime.driver import driver
  File "/home/user/user/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/python/triton/runtime/driver.py", line 1, in <module>
    from ..backends import backends
  File "/home/user/user/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/python/triton/backends/__init__.py", line 50, in <module>
    backends = _discover_backends()
  File "/home/user/user/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/python/triton/backends/__init__.py", line 43, in _discover_backends
    compiler = _load_module(name, os.path.join(root, name, 'compiler.py'))
  File "/home/user/user/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/python/triton/backends/__init__.py", line 12, in _load_module
    spec.loader.exec_module(module)
  File "/home/user/user/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/python/triton/backends/intel/compiler.py", line 3, in <module>
    from triton.backends.intel.driver import XPUUtils
  File "/home/user/user/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/python/triton/backends/intel/driver.py", line 10, in <module>
    import intel_extension_for_pytorch as ipex
  File "/home/user/user/frameworks.ai.pytorch.ipex-gpu/intel_extension_for_pytorch/__init__.py", line 123, in <module>
    from . import _inductor
  File "/home/user/user/frameworks.ai.pytorch.ipex-gpu/intel_extension_for_pytorch/_inductor/__init__.py", line 1, in <module>
    from . import xpu
  File "/home/user/user/frameworks.ai.pytorch.ipex-gpu/intel_extension_for_pytorch/_inductor/xpu/__init__.py", line 4, in <module>
    from .codegen.triton import XPUTritonScheduling
  File "/home/user/user/frameworks.ai.pytorch.ipex-gpu/intel_extension_for_pytorch/_inductor/xpu/codegen/triton.py", line 17, in <module>
    from torch._inductor.codegen.triton import (
  File "/home/user/user/frameworks.ai.pytorch.private-gpu/torch/_inductor/codegen/triton.py", line 28, in <module>
    from ..triton_heuristics import AutotuneHint
  File "/home/user/user/frameworks.ai.pytorch.private-gpu/torch/_inductor/triton_heuristics.py", line 44, in <module>
    from triton import Config
ImportError: cannot import name 'Config' from partially initialized module 'triton' (most likely due to a circular import) (/home/user/user/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/python/triton/__init__.py)
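One way to avoid importing the framework while Triton itself is still initializing is to resolve the heavyweight module lazily, on first use. The sketch below is illustrative only: `LazyModule` is a hypothetical helper, and the stdlib `json` module stands in for intel_extension_for_pytorch so the example is runnable.

```python
import importlib
import sys

class LazyModule:
    """Proxy that imports the wrapped module only on first attribute access."""

    def __init__(self, name):
        self._name = name
        self._mod = None

    def __getattr__(self, attr):
        # __getattr__ only fires for attributes missing from the instance,
        # so _name/_mod lookups above never recurse into this method.
        if self._mod is None:
            self._mod = importlib.import_module(self._name)
        return getattr(self._mod, attr)

sys.modules.pop("json", None)       # clean slate for the demo
ipex = LazyModule("json")           # nothing imported yet...
assert "json" not in sys.modules
print(ipex.dumps({"a": 1}))         # ...until the first real use
assert "json" in sys.modules
```

With this pattern, importing the driver module no longer triggers the framework's import chain, so the cycle shown in the traceback cannot start.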
LiyangLingIntel commented 8 months ago

Edit: fixed in #374

> Triton currently has a circular import error. This is because Triton imports ipex in driver.py; ipex then imports triton_heuristics.py, which does `from triton import Config`. Hence the circular import.

I have left a review comment in PR #374.

After fixing this, it looks like IPEX needs to align the device capability with PyTorch upstream. Previously IPEX passed the capability as a Python dictionary, but the current Triton compiler requires an integer; there is an assertion at https://github.com/intel/intel-xpu-backend-for-triton/blob/734343e5b724822351f40acd6aebc4f8e8b4dd42/third_party/intel/backend/compiler.py#L372-L373. @Stonepia, could you please take a look?

Stonepia commented 8 months ago

> After fixing this, it looks like IPEX needs to align the device capability with PyTorch upstream. Previously IPEX passed the capability as a Python dictionary, but the current Triton compiler requires an integer; there is an assertion at https://github.com/intel/intel-xpu-backend-for-triton/blob/734343e5b724822351f40acd6aebc4f8e8b4dd42/third_party/intel/backend/compiler.py#L372-L373. @Stonepia, could you please take a look?

After discussion, we temporarily return a hard-coded integer for the capability to align with upstream. We will refine this later.
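The dict-vs-int mismatch could also be bridged with a small normalization shim. This is only an illustration: the field names and the major*10+minor encoding mirror the CUDA convention and are assumptions, not what the Intel backend actually does (which, per the discussion above, temporarily hard-codes an integer).

```python
def normalize_capability(cap):
    """Coerce a device capability to the integer form the compiler asserts on.

    Accepts either an int (passed through) or a legacy dict such as
    {"major": 8, "minor": 6}. Field names and the major*10+minor encoding
    are illustrative assumptions borrowed from the CUDA convention.
    """
    if isinstance(cap, int):
        return cap
    if isinstance(cap, dict):
        return cap["major"] * 10 + cap["minor"]
    raise TypeError(f"unsupported capability type: {type(cap).__name__}")

print(normalize_capability({"major": 8, "minor": 6}))  # 86
print(normalize_capability(90))                        # 90
```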

LiyangLingIntel commented 8 months ago

With PRs https://github.com/intel/intel-xpu-backend-for-triton/pull/374 and https://github.com/intel/intel-xpu-backend-for-triton/pull/410 merged, we can run Inductor E2E models with IPEX + Triton 3.0; there should be no problem with the APIs.