intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License
114 stars, 30 forks

error with use of tensor.dtype.element_ty #629

Closed: rgiduthuri-intel closed this issue 2 months ago

rgiduthuri-intel commented 4 months ago

The use of tensor.dtype.element_ty in the Triton kernel below produces an error (pasted at the end of this message). I'm using the latest llvm-target build as of a few hours ago. (PyTorch 2.1 with CUDA did not report the error.) I'd appreciate any suggestions for a quick workaround. Thanks

import torch
import triton
import triton.language as tl

@triton.jit
def kernel(
      out, e_dst, e_src, adj1, adj2, num_rows,
      WG_SIZE : tl.constexpr
  ):
    zero_indices = tl.zeros((WG_SIZE,), dtype=adj2.dtype.element_ty)
    zero_values = tl.zeros((WG_SIZE,), dtype=e_src.dtype.element_ty)

if __name__ == '__main__':
    if torch.cuda.device_count() > 0:
        device = torch.device('cuda')
    else:
        import intel_extension_for_pytorch
        device = torch.device('xpu')

    num_rows, num_cols, num_indices, heads = 16, 16, 16, 4
    e_src = torch.randn((num_cols,heads), dtype=torch.float32, device=device) + 0.5
    e_dst = torch.randn((num_rows,heads), dtype=torch.float32, device=device) + 0.5
    adj1 = torch.zeros(num_rows + 1, dtype=torch.int32, device=device)
    adj2 = torch.randint(num_rows, (num_indices,), dtype=torch.int32, device=device)
    out = torch.empty((num_indices,heads), dtype=e_src.dtype, device=e_src.device)
    # invoke triton kernel
    launch_grid = (triton.cdiv(num_rows * heads, 256),)  # round up: num_rows * heads // 256 would be 0 here
    kernel[launch_grid](out, e_dst, e_src, adj1, adj2, num_rows, 256)
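
A note on the launch grid: with the sizes in this repro, `num_rows * heads` is 64, so floor division by the 256-wide work-group (`num_rows * heads // 256`) gives a grid of zero programs, while ceiling division (what `triton.cdiv` computes) gives one. The `element_ty` error still shows up with the empty grid, since it is raised while the kernel source is being analyzed, before anything is launched; but once that error is gone, a zero-sized grid would run no kernel instances at all. The sketch below writes `cdiv` out locally (a hypothetical stand-in with the same rounding) so it runs without Triton installed.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive ints (same rounding as triton.cdiv)."""
    return (a + b - 1) // b


# Sizes taken from the repro: 16 rows x 4 heads, work-group size 256.
num_rows, heads, wg_size = 16, 4, 256

floor_grid = (num_rows * heads // wg_size,)       # 64 // 256 -> 0 programs
ceil_grid = (cdiv(num_rows * heads, wg_size),)    # rounds 64 / 256 up -> 1 program
print(floor_grid, ceil_grid)  # (0,) (1,)
```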

Here's the error message:

  File "/home/rgiduthu/miniconda3/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/rgiduthu/triton-dev-llvm-target/intel-xpu-backend-for-triton/python/triton/runtime/jit.py", line 47, in visit_Attribute
    return getattr(lhs, node.attr)
           ^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'torch.dtype' object has no attribute 'element_ty'

Sarbojit2019 commented 4 months ago

Hello @rgiduthuri-intel,

I made a minor modification to your kernel code, and with it the sample code runs successfully. While I try to root-cause the issue, this workaround should unblock you.

========= Original code ===========
@triton.jit
def kernel(
      out, e_dst, e_src, adj1, adj2, num_rows,
      WG_SIZE : tl.constexpr
  ):
    zero_indices = tl.zeros((WG_SIZE,), dtype=adj2.dtype.element_ty)
    zero_values = tl.zeros((WG_SIZE,), dtype=e_src.dtype.element_ty)

============= Modified code ==========
@triton.jit
def kernel(
      out_, e_dst_, e_src_, adj1_, adj2_, num_rows_,
      WG_SIZE : tl.constexpr
  ):
    zero_indices = tl.zeros((WG_SIZE,), dtype=adj2_.dtype.element_ty)
    zero_values = tl.zeros((WG_SIZE,), dtype=e_src_.dtype.element_ty)

It looks to me like a name-conflict issue, since renaming the kernel parameters (note the trailing '_' added to each parameter) resolved the error.
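
The name-conflict diagnosis is consistent with the traceback. The error is raised by `getattr(lhs, node.attr)` inside a `visit_Attribute` method in `triton/runtime/jit.py`, and `lhs` there is a `torch.dtype`. That is what you would get if the walker resolved `e_src` and `adj2` to the tensors the script defines at module level (code under `if __name__ == '__main__':` runs in the module's global scope) rather than to the kernel parameters, so that `e_src.dtype` became a real `torch.dtype`, which has no `element_ty`. The sketch below is a hypothetical, heavily simplified model of such a globals-only lookup, built on Python's standard `ast` module. `GlobalsFinder`, `walk`, `FakeTensor` and `FakeDtype` are illustrative names, not Triton code, and the parameter-skipping variant is just one assumed way to avoid the clash, not a description of any upstream fix.

```python
import ast


class FakeDtype:
    """Stand-in for torch.dtype: like it, it has no `element_ty` attribute."""


class FakeTensor:
    """Stand-in for a module-level tensor such as `e_src` or `adj2` in the repro."""
    def __init__(self):
        self.dtype = FakeDtype()


class GlobalsFinder(ast.NodeVisitor):
    """Hypothetical dependency walker that resolves names against module globals.

    Names listed in `local_names` (e.g. function parameters) are skipped.
    """

    def __init__(self, globals_, local_names=frozenset()):
        self.globals = globals_
        self.local_names = local_names

    def visit_Name(self, node):
        if node.id in self.local_names:
            return None  # a local/parameter, not a global dependency
        return self.globals.get(node.id)

    def visit_Attribute(self, node):
        lhs = self.visit(node.value)
        if lhs is None:
            return None
        return getattr(lhs, node.attr)  # mirrors the getattr in the traceback


def walk(src, globals_, skip_params):
    """Parse `src` (never executed) and walk it with GlobalsFinder."""
    tree = ast.parse(src)
    params = {a.arg for a in tree.body[0].args.args} if skip_params else set()
    GlobalsFinder(globals_, frozenset(params)).visit(tree)


ORIGINAL = """
def kernel(out, e_dst, e_src, adj1, adj2, num_rows, WG_SIZE):
    zero_indices = tl.zeros((WG_SIZE,), dtype=adj2.dtype.element_ty)
    zero_values = tl.zeros((WG_SIZE,), dtype=e_src.dtype.element_ty)
"""
# The workaround: add a trailing '_' so parameters no longer match globals.
RENAMED = ORIGINAL.replace("e_src", "e_src_").replace("adj2", "adj2_")

# The script's module-level tensors share names with the kernel parameters.
script_globals = {"e_src": FakeTensor(), "adj2": FakeTensor()}

original_error = None
try:
    walk(ORIGINAL, script_globals, skip_params=False)
except AttributeError as exc:
    original_error = str(exc)
print("original:", original_error)  # original: 'FakeDtype' object has no attribute 'element_ty'

walk(RENAMED, script_globals, skip_params=False)  # renamed parameters: no error
walk(ORIGINAL, script_globals, skip_params=True)  # parameters treated as locals: no error
print("renamed and parameter-aware walks succeeded")
```

Renaming works because `adj2_` and `e_src_` simply are not keys in the globals dict, so the lookup stops at `None` instead of reaching a real tensor.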

Sarbojit2019 commented 4 months ago

@rgiduthuri-intel, I investigated further and found that the test also fails on an NVIDIA machine with the latest Triton (branch: main). You probably installed Triton with pip, which installs Triton 2.2.1; that is why it succeeded for you. intel-xpu-backend-for-triton is built from the latest Triton source, which is why you see the failure here.

Conclusion: this is not an Intel Triton issue but a Triton 3.0 issue. @chengjunlu has already raised a fix. Until the fix is merged and upstreamed, please use the workaround I suggested.

Thanks

rgiduthuri-intel commented 4 months ago

Great! Thanks for the workaround. Please feel free to close this issue.

whitneywhtsang commented 4 months ago

@Sarbojit2019 Posted a PR in the OpenAI repo to fix this issue: https://github.com/openai/triton/pull/3383