intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

Sample Code throws Error: NotImplementedError: Could not run 'aten::empty.memory_format' with arguments from the 'XPU' backend #1640

Open gurwinderintel opened 1 month ago

gurwinderintel commented 1 month ago

Error

Ran the following sample code on an Intel PVC GPU:

import torch
from torch._dynamo.testing import rand_strided
class simpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # tensors inside model should be on xpu
        self.y = rand_strided((32, 8), (8, 1), device='xpu', dtype=torch.float32)

    def forward(self, x):
        z = x + self.y
        return z

# tensors passed to the model should be on xpu
x = rand_strided((32, 8), (8, 1), device='xpu', dtype=torch.float32)
xpu_model = simpleModel()
# Call torch.compile for optimization
optimized_mod = torch.compile(xpu_model)
graph_result = optimized_mod(x)

print(graph_result)

Output:
Traceback (most recent call last):
  File "/home/gta/Triton.py", line 16, in <module>
    x = rand_strided((32, 8), (8, 1), device='xpu', dtype=torch.float32)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/testing.py", line 301, in rand_strided
    buffer = torch.randn(needed_size, dtype=dtype, device=device)
NotImplementedError: Could not run 'aten::empty.memory_format' with arguments from the 'XPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::empty.memory_format' is only available for these backends: [CPU, Meta, QuantizedCPU, QuantizedMeta, MkldnnCPU, SparseCPU, SparseMeta, SparseCsrCPU, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

CPU: registered at /workspace/repositories/IPEX/pytorch/build/aten/src/ATen/RegisterCPU.cpp:31188 [kernel]
Meta: registered at /workspace/repositories/IPEX/pytorch/build/aten/src/ATen/RegisterMeta.cpp:26829 [kernel]
QuantizedCPU: registered at /workspace/repositories/IPEX/pytorch/build/aten/src/ATen/RegisterQuantizedCPU.cpp:951 [kernel]
QuantizedMeta: registered at /workspace/repositories/IPEX/pytorch/build/aten/src/ATen/RegisterQuantizedMeta.cpp:105 [kernel]
MkldnnCPU: registered at /workspace/repositories/IPEX/pytorch/build/aten/src/ATen/RegisterMkldnnCPU.cpp:515 [kernel]
SparseCPU: registered at /workspace/repositories/IPEX/pytorch/build/aten/src/ATen/RegisterSparseCPU.cpp:1387 [kernel]
SparseMeta: registered at /workspace/repositories/IPEX/pytorch/build/aten/src/ATen/RegisterSparseMeta.cpp:249 [kernel]
SparseCsrCPU: registered at /workspace/repositories/IPEX/pytorch/build/aten/src/ATen/RegisterSparseCsrCPU.cpp:1135 [kernel]
BackendSelect: registered at /workspace/repositories/IPEX/pytorch/build/aten/src/ATen/RegisterBackendSelect.cpp:742 [kernel]
Python: registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:153 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]
Functionalize: registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/FunctionalizeFallbackKernel.cpp:290 [backend fallback]
Named: registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: fallthrough registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/ConjugateFallback.cpp:21 [kernel]
Negative: fallthrough registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/native/NegateFallback.cpp:23 [kernel]
ZeroTensor: fallthrough registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/ZeroTensorFallback.cpp:90 [kernel]
ADInplaceOrView: fallthrough registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
AutogradCPU: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
AutogradCUDA: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
AutogradHIP: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
AutogradXLA: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
AutogradMPS: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
AutogradIPU: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
AutogradXPU: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
AutogradHPU: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
AutogradVE: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
AutogradLazy: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
AutogradMTIA: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
AutogradPrivateUse1: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
AutogradPrivateUse2: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
AutogradPrivateUse3: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
AutogradMeta: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
AutogradNestedTensor: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:18610 [autograd kernel]
Tracer: registered at /workspace/repositories/IPEX/pytorch/torch/csrc/autograd/generated/TraceType_2.cpp:17079 [kernel]
AutocastCPU: fallthrough registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/autocast_mode.cpp:382 [backend fallback]
AutocastCUDA: fallthrough registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/autocast_mode.cpp:249 [backend fallback]
FuncTorchBatched: registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:710 [backend fallback]
FuncTorchVmapMode: fallthrough registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]
PythonTLSSnapshot: registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:161 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]
PreDispatch: registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:165 [backend fallback]
PythonDispatcher: registered at /workspace/repositories/IPEX/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:157 [backend fallback]

If I instead add import intel_extension_for_pytorch as ipex at the top of the script, the XPU tensor allocation succeeds. For reference, only the import block changes (a sketch; the rest of the repro is unchanged):
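import torch
import intel_extension_for_pytorch as ipex  # registers XPU device support with this PyTorch 2.1 build
from torch._dynamo.testing import rand_strided

With that change the script gets past the allocation, but torch.compile then fails inside the Inductor backend at optimized_mod(x):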


Traceback (most recent call last):
  File "/home/gta/Triton.py", line 20, in <module>
    graph_result = optimized_mod(x)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
    tracer.run()
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
    super().run()
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2162, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 833, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 957, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1024, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1009, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/__init__.py", line 1568, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1150, in compile_fx
    return aot_autograd(
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3891, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3429, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2212, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2392, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1573, in aot_dispatch_base
    compiled_fw = compiler(fw_module, flat_args)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1092, in fw_compiler_base
    return inner_compile(
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 80, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_inductor/debug.py", line 228, in inner
    return fn(*args, **kwargs)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 54, in newFunction
    return old_func(*args, **kwargs)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 341, in compile_fx_inner
    compiled_graph: CompiledFxGraph = fx_codegen_and_compile(
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 565, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_inductor/graph.py", line 970, in compile_to_fn
    return self.compile_to_module().call
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_inductor/graph.py", line 941, in compile_to_module
    mod = PyCodeCache.load_by_key_path(key, path, linemap=linemap)
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1143, in load_by_key_path
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_root/i3/ci3z52hmpoqrskqb3lgkd34craii22y7invmpeo2q45sx65qcbqx.py", line 55, in <module>
    async_compile.wait(globals())
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1431, in wait
    scope[key] = result.result()
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1290, in result
    self.future.result()
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/root/miniconda3/envs/intel-prof/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: compile() got an unexpected keyword argument 'signature'

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
vlad-penkin commented 1 month ago

@gurwinderintel could you please provide more details on your environment?
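For instance, the output of the following would help (a sketch; torch.utils.collect_env ships with PyTorch, and sycl-ls comes with the oneAPI runtime):

python -m torch.utils.collect_env   # summarizes PyTorch, OS, and GPU driver versions
sycl-ls                             # lists the SYCL / Level Zero devices visible to the runtime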

gurwinderintel commented 1 month ago

Versions:

  GPU Driver: agama-ci-devel-803.44
  torch==2.1.0
  intel_extension_for_pytorch==2.1.20+xpu
  oneAPI BaseKit==2024.2.0.559

alexbaden commented 3 weeks ago

I would recommend using PyTorch 2.4 without IPEX; I suspect the PyTorch/IPEX combination you have is too old.

vlad-penkin commented 3 weeks ago

@gurwinderintel

You are using an old LTS driver; the current one is 803.75. Triton is functional on the LTS driver, but you won't get good performance. We recommend using the latest Rolling driver.
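A quick way to check the installed driver version (a sketch, assuming the agama packages were installed through apt on Ubuntu):

apt list --installed 2>/dev/null | grep -i level-zero   # shows the installed Level Zero driver packages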

Triton will work only with:

  1. The PyTorch Dependency Bundle, not the regular oneAPI BaseKit. Upstream PyTorch and Triton XPU will start supporting the oneAPI BaseKit with the 2025.0 release, as per current plans.

  2. Upstream PyTorch >= 2.4, with or without the Triton-specific pending PR's applied. Regular IPEX is not supported as of now; we are in the process of deprecating the dependency on the special PT/IPEX 2.1 test proxies, and we do not recommend using those. (A quick sanity check for a working XPU build is sketched after this list.)
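To confirm that the installed PyTorch build has native XPU support (a sketch; the torch.xpu module is part of upstream PyTorch 2.4+):

python -c "import torch; print(torch.__version__, torch.xpu.is_available())"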

To build upstream PyTorch from source, run the following script:

./scripts/compile-pytorch-ipex.sh --pytorch --upstream-pytorch --source

Our tutorials code still has an import intel_extension_for_pytorch line. You can either comment it out or install a dummy no-op IPEX package using this script:

from os import chdir, makedirs
from tempfile import TemporaryDirectory
from subprocess import run

with TemporaryDirectory() as tmpdir:
    pkg = "intel_extension_for_pytorch"
    chdir(tmpdir)
    makedirs(pkg, exist_ok=True)
    # Minimal package skeleton: an empty module plus just enough metadata
    # for `python -m build` to produce an installable wheel.
    files = {
        f"{pkg}/__init__.py": "",
        "setup.py": (
            "from setuptools import setup, find_packages\n"
            f"setup(name='{pkg}', version='2', packages=find_packages())"
        ),
        "pyproject.toml": (
            "[build-system]\n"
            "requires = [\"setuptools\", \"wheel\"]\n"
            "build-backend = \"setuptools.build_meta\""
        )
    }
    for file, content in files.items():
        with open(file, "w") as f:
            f.write(content)
    # Swap any real IPEX installation for the no-op stub.
    cmds = [
        f"pip uninstall -y {pkg}",
        "pip install build",
        "python -m build .",
        f"pip install dist/{pkg}-2-py3-none-any.whl"
    ]
    for cmd in cmds:
        run(cmd.split(), check=True)
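
Once the stub wheel is installed, the import in the tutorials becomes a harmless no-op; a quick check:

python -c "import intel_extension_for_pytorch; print('stub ipex imported')"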