intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

Triton tutorial example 03-matrix-multiplication.py failed #706

Closed: rgiduthuri-intel closed this issue 2 months ago

rgiduthuri-intel commented 6 months ago

I'm not able to run the Triton GEMM tutorial example using intel-xpu-backend-for-triton. Below are the local changes I made to "triton/python/tutorials/03-matrix-multiplication.py" to use XPU, followed by the error messages from the latest commit and from a build a few weeks old. Strangely, they are different kinds of errors.

I'd appreciate a quick workaround if possible. Thanks.

% git diff
diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py
index d2b9b65b0..3cababc26 100644
--- a/python/tutorials/03-matrix-multiplication.py
+++ b/python/tutorials/03-matrix-multiplication.py
@@ -300,9 +300,16 @@ def matmul(a, b, activation=""):
 #
 # We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS).

+# pick device
+if torch.cuda.device_count() > 0:
+    device = torch.device('cuda')
+else:
+    import intel_extension_for_pytorch
+    device = torch.device('xpu')
+
 torch.manual_seed(0)
-a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
-b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
+a = torch.randn((512, 512), device=device, dtype=torch.float16)
+b = torch.randn((512, 512), device=device, dtype=torch.float16)
 triton_output = matmul(a, b)
 torch_output = torch.matmul(a, b)
 print(f"triton_output={triton_output}")
@@ -339,8 +346,8 @@ else:
         args={},
     ))
 def benchmark(M, N, K, provider):
-    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
-    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
+    a = torch.randn((M, K), device=device, dtype=torch.float16)
+    b = torch.randn((K, N), device=device, dtype=torch.float16)
     quantiles = [0.5, 0.2, 0.8]
     if provider == 'cublas':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
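
The device-selection logic in the diff above can be factored into a small helper. This is a sketch only: `pick_device` and `probe_and_pick` are hypothetical names, not part of the tutorial or of intel-xpu-backend-for-triton. The pure function takes the probe results as arguments so it can be tested without a GPU, and it adds a `'cpu'` fallback that the diff does not have.

```python
def pick_device(cuda_count: int, xpu_available: bool) -> str:
    """Return the torch device string to use.

    Hypothetical helper: keeps the diff's preference order (CUDA first,
    then XPU) and falls back to CPU when neither is present.
    """
    if cuda_count > 0:
        return "cuda"
    if xpu_available:
        return "xpu"
    return "cpu"


def probe_and_pick() -> str:
    """Probe the real runtime; needs torch, and IPEX for the XPU path."""
    import torch

    xpu_ok = False
    try:
        # As in the diff, importing IPEX makes the 'xpu' device usable from torch.
        import intel_extension_for_pytorch  # noqa: F401
        xpu_ok = hasattr(torch, "xpu") and torch.xpu.is_available()
    except ImportError:
        pass
    return pick_device(torch.cuda.device_count(), xpu_ok)
```

In the tutorial this would become `device = torch.device(probe_and_pick())`, which avoids an unconditional IPEX import on machines that have neither CUDA nor XPU.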

With intel-xpu-backend-for-triton at the latest commit 0469c4053cfefe7957f2127f7b41900baf934c4a:

% python 03-matrix-multiplication.py
Traceback (most recent call last):
  File "/home/rgiduthu/triton/triton/python/tutorials/03-matrix-multiplication.py", line 313, in <module>
    triton_output = matmul(a, b)
                    ^^^^^^^^^^^^
  File "/home/rgiduthu/triton/triton/python/tutorials/03-matrix-multiplication.py", line 286, in matmul
    matmul_kernel[grid](
  File "/home/rgiduthu/triton-dev-llvm-target/intel-xpu-backend-for-triton/python/triton/runtime/jit.py", line 180, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rgiduthu/triton-dev-llvm-target/intel-xpu-backend-for-triton/python/triton/runtime/autotuner.py", line 141, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rgiduthu/triton-dev-llvm-target/intel-xpu-backend-for-triton/python/triton/runtime/autotuner.py", line 141, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rgiduthu/triton-dev-llvm-target/intel-xpu-backend-for-triton/python/triton/runtime/autotuner.py", line 120, in _bench
    return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rgiduthu/triton-dev-llvm-target/intel-xpu-backend-for-triton/python/triton/testing.py", line 110, in do_bench
    fn()
  File "/home/rgiduthu/triton-dev-llvm-target/intel-xpu-backend-for-triton/python/triton/runtime/autotuner.py", line 110, in kernel_call
    self.fn.run(
  File "/home/rgiduthu/triton-dev-llvm-target/intel-xpu-backend-for-triton/python/triton/runtime/jit.py", line 361, in run
    device = driver.active.get_current_device()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rgiduthu/triton-dev-llvm-target/intel-xpu-backend-for-triton/python/triton/backends/intel/driver.py", line 395, in get_current_device
    return self.utils.get_current_device()
           ^^^^^^^^^^
  File "/home/rgiduthu/triton-dev-llvm-target/intel-xpu-backend-for-triton/python/triton/backends/intel/driver.py", line 389, in __getattr__
    self.utils = XPUUtils()
                 ^^^^^^^^^^
  File "/home/rgiduthu/triton-dev-llvm-target/intel-xpu-backend-for-triton/python/triton/backends/intel/driver.py", line 56, in __init__
    self.context = mod.init_context(self.get_sycl_queue())
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: an integer is required
%

With an intel-xpu-backend-for-triton commit from a few weeks back (I don't have the exact commit):

% python 03-matrix-multiplication.py
L0 build module failed. Log:
error: total scratch space exceeds HW supported limit for kernel matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c: 339200 bytes (max permitted PTSS 262144 bytes)
error: backend compiler failed build.

LIBXSMM_VERSION: main_stable-1.17-3651 (25693763)LIBXSMM WARNING: AMX state allocation in the OS failed!

LIBXSMM_TARGET: clx [Genuine Intel(R) CPU 0000%@]
Registry and code: 13 MB
Command: python 03-matrix-multiplication.py
Uptime: 9.757146 s
Segmentation fault (core dumped)
%
etiotto commented 6 months ago

This issue is caused by an environment problem: running the Triton end-to-end tests requires a newer version of IPEX. Currently we build PyTorch and IPEX from source as part of the Triton build (see scripts/compile-pytorch-ipex.sh); that script builds both from source and installs them. Note that you have to uninstall the existing versions first (pip uninstall ...); otherwise the script will not replace them with the newly built ones.

In my environment I have:

I can run the tutorial correctly using the latest Triton commit.

@rgiduthuri-intel, can you please give it a try and report back on whether it solves the problem you are facing?
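
Since end-to-end runs need a matching IPEX build and stale installs are not replaced by the build script, it helps to report the exact installed versions when comparing environments. A minimal sketch using only the standard library's `importlib.metadata`; the distribution names in `PACKAGES` are assumptions (a local source build may register under a different name), and the script only reports versions, it does not know which ones the repository pins.

```python
from importlib import metadata

# Assumed distribution names; adjust if your builds register differently.
PACKAGES = ("torch", "intel-extension-for-pytorch", "triton")


def installed_versions(packages=PACKAGES) -> dict:
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:32} {version or 'NOT INSTALLED'}")
```

A `None` for IPEX or two different torch versions between machines is a quick hint that `pip uninstall ...` was skipped before rebuilding.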

pbchekin commented 6 months ago

@rgiduthuri-intel, we also build torch, ipex, and triton wheels nightly and attach them as artifacts to the corresponding job. This is the latest run, for example: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/8334844255. You can download the artifact for your Python version and install all wheels with pip install *.whl. Let us know if that works for you.
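
For reference, `pip install *.whl` relies on shell globbing. The sketch below builds the equivalent command from Python, using the running interpreter's pip so the wheels land in the active environment; `pip_install_command` is a hypothetical helper and the artifact directory name is up to you.

```python
import sys
from pathlib import Path


def pip_install_command(wheel_dir) -> list:
    """Build the equivalent of `pip install *.whl` for one directory.

    Passing all wheels to a single pip call lets pip resolve the torch,
    IPEX and triton wheels together in one step.
    """
    wheels = sorted(str(p) for p in Path(wheel_dir).glob("*.whl"))
    if not wheels:
        raise FileNotFoundError(f"no .whl files found in {wheel_dir}")
    # `sys.executable -m pip` targets the interpreter that runs this script.
    return [sys.executable, "-m", "pip", "install", *wheels]
```

Usage: `subprocess.run(pip_install_command("downloaded-artifact"), check=True)`, where `"downloaded-artifact"` stands for wherever the artifact for your Python version was unpacked.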

rgiduthuri-intel commented 6 months ago

@etiotto: now I'm able to run with a clean build using both scripts/compile-triton.sh --env and scripts/compile-pytorch-ipex.sh. @pbchekin, I will try downloading the .whl files going forward. Thank you.

Now that the build problem is resolved, I'm back to the 03-matrix-multiplication.py crash that I mentioned in my original message for the older build. I'd appreciate your help.

$ python 03-matrix-multiplication.py
L0 build module failed. Log:
error: total scratch space exceeds HW supported limit for kernel matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c: 338112 bytes (max permitted PTSS 262144 bytes)
error: backend compiler failed build.

LIBXSMM_VERSION: main_stable-1.17-3651 (25693763)LIBXSMM WARNING: AMX state allocation in the OS failed!

LIBXSMM_TARGET: clx [Genuine Intel(R) CPU 0000%@]
Registry and code: 13 MB
Command: python 03-matrix-multiplication.py
Uptime: 15.185146 s
Segmentation fault (core dumped)
etiotto commented 6 months ago

@rgiduthuri-intel, so using the latest Triton version and the pinned versions of PyTorch and IPEX you are able to run the tutorial, correct?

I do not know how to reproduce the issue you have with the older Triton build. In general we aim to fix problems that are reproducible using the latest version of Triton. Fixing older builds is not something we can do.

It is possible that your older binary fails to run because the IPEX (or PyTorch) version is not the required one. The interface between PyTorch/IPEX and Triton has to be in sync.

@pbchekin, do we have older versions of Triton saved somewhere?

pbchekin commented 6 months ago

@rgiduthuri-intel, do you have the same issue with the latest pytorch, ipex, and triton wheels? Looking at the error message:

total scratch space exceeds HW supported limit

I would say there is a hardware limitation. What GPU do you have?
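
The limit in the log can be checked with a little arithmetic: the reported PTSS limit (presumably per-thread scratch space) is 262144 bytes, i.e. 256 KiB, and both builds in this thread overshoot it by roughly 75 KB. The sketch below parses the two log lines quoted above; the regex assumes exactly the message format shown here and is not a documented interface of the L0 backend compiler.

```python
import re

# The two scratch-space errors quoted in this thread (older build, then rebuilt env).
LOGS = [
    "error: total scratch space exceeds HW supported limit for kernel "
    "matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c: 339200 bytes (max permitted PTSS 262144 bytes)",
    "error: total scratch space exceeds HW supported limit for kernel "
    "matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c: 338112 bytes (max permitted PTSS 262144 bytes)",
]

PATTERN = re.compile(r"kernel (\S+): (\d+) bytes \(max permitted PTSS (\d+) bytes\)")


def scratch_overshoot(line: str):
    """Return (kernel, used, limit, excess) parsed from one build-log line."""
    m = PATTERN.search(line)
    if m is None:
        raise ValueError("not a scratch-space error line")
    kernel, used, limit = m.group(1), int(m.group(2)), int(m.group(3))
    return kernel, used, limit, used - limit


for line in LOGS:
    kernel, used, limit, excess = scratch_overshoot(line)
    print(f"{used} B used vs {limit} B limit: {excess} B over ({used / limit:.0%} of limit)")
```

Both lines come out at about 129% of the limit (77056 and 75968 bytes over), so the failure is consistent across builds rather than a one-off.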

whitneywhtsang commented 6 months ago

It is a known issue. You can comment out the problematic configuration, as is done in https://github.com/intel/intel-xpu-backend-for-triton/blob/llvm-target/python/tutorials/03-matrix-multiplication.py#L167.

@@ -163,8 +164,9 @@ import triton.language as tl
 #       provided configs
 @triton.autotune(
     configs=[
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
-                      num_warps=8),
+        # FIXME: Once tl.dot uses DPAS put back the workload commented out.
+        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
+        #               num_warps=8),
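
The workaround above removes the 128x256x64 configuration by hand. A programmatic alternative is sketched below, with plain dicts standing in for `triton.Config` kwargs so it runs without Triton. The rule (drop any config whose fp32 accumulator tile, BLOCK_SIZE_M x BLOCK_SIZE_N x 4 bytes, exceeds a budget) and the 64 KiB budget are hypothetical heuristics, not what the repository does; the FIXME attributes the underlying problem to `tl.dot` not yet using DPAS. Only the first config below is from the diff; the rest are illustrative.

```python
# Stand-ins for triton.Config kwargs; entry 0 is the config commented out above.
CONFIGS = [
    {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
    {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
    {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
    {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
]

ACC_BYTES = 4          # the tutorial accumulates in fp32
BUDGET = 64 * 1024     # hypothetical accumulator-tile budget in bytes


def accumulator_bytes(cfg: dict) -> int:
    """Size of the BLOCK_SIZE_M x BLOCK_SIZE_N fp32 accumulator tile."""
    return cfg["BLOCK_SIZE_M"] * cfg["BLOCK_SIZE_N"] * ACC_BYTES


def prune(configs: list, budget: int = BUDGET) -> list:
    """Keep only configs whose accumulator tile fits within `budget`."""
    return [c for c in configs if accumulator_bytes(c) <= budget]
```

With Triton installed, the same predicate could be applied to `cfg.kwargs` of real `triton.Config` objects, for example from an `early_config_prune` callback passed through `@triton.autotune(..., prune_configs_by=...)`; check the expected callback signature for your Triton version.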
vlad-penkin commented 2 months ago

@rgiduthuri-intel, can we close this ticket?

rgiduthuri-intel commented 2 months ago

Please close it if it's taken care of. Thanks