intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

Trying to run vector addition example with latest v3.0.0b2 released wheel #1110

Closed: fcharras closed this issue 4 months ago

fcharras commented 4 months ago

After installing Triton v3.0.0b2 from the latest wheel released at https://github.com/intel/intel-xpu-backend-for-triton/releases, I'm trying to run the vector addition example from triton-lang.org, slightly modified so that it targets `xpu` devices rather than `cuda` devices:

```python
import intel_extension_for_pytorch
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output.
    output = torch.empty_like(x)
    assert x.is_xpu and y.is_xpu and output.is_xpu
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks:
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    # NOTE:
    #  - Each torch.tensor object is implicitly converted into a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
    #  - Don't forget to pass meta-parameters as keyword arguments.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # We return a handle to `output` but, since `torch.xpu.synchronize()` hasn't been called, the kernel is still
    # running asynchronously at this point.
    return output


torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='xpu')
y = torch.rand(size, device='xpu')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')
```
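The kernel's comments describe how the 1D grid splits the vector into `BLOCK_SIZE` chunks and how the mask guards the tail block. A minimal pure-Python sketch of that decomposition (it runs on the CPU, not with Triton; `cdiv`, `block_ranges` and `emulate_add` are hypothetical helpers, with `cdiv` doing the same ceiling division as `triton.cdiv`) makes the arithmetic concrete:

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division: number of blocks of size b needed to cover a elements."""
    return (a + b - 1) // b


def block_ranges(n_elements: int, block_size: int) -> list[tuple[int, int]]:
    """Half-open [start, stop) range of in-bounds elements handled by each program id."""
    return [
        (pid * block_size, min((pid + 1) * block_size, n_elements))
        for pid in range(cdiv(n_elements, block_size))
    ]


def emulate_add(x: list[float], y: list[float], block_size: int) -> list[float]:
    """Sequentially emulate the SPMD add_kernel launch on plain Python lists."""
    n_elements = len(x)
    output = [0.0] * n_elements  # preallocated, like torch.empty_like(x)
    grid = (cdiv(n_elements, block_size),)
    for pid in range(grid[0]):  # each iteration stands in for one program instance
        block_start = pid * block_size
        offsets = [block_start + i for i in range(block_size)]  # like tl.arange
        mask = [off < n_elements for off in offsets]  # guards the tail block
        for off, in_bounds in zip(offsets, mask):
            if in_bounds:  # masked-out lanes neither load nor store
                output[off] = x[off] + y[off]
    return output
```

With the snippet's `size = 98432` and `BLOCK_SIZE=1024`, `cdiv` gives a grid of 97 programs, and only the first 128 lanes of the last program are unmasked (`98432 - 96 * 1024 = 128`).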

Running the modified snippet fails with the following traceback:

```
Traceback (most recent call last):
  File "/home/uaf04777df9ce9aeb441c203bc88f3d1/./script.py", line 59, in <module>
    output_triton = add(x, y)
                    ^^^^^^^^^
  File "/home/uaf04777df9ce9aeb441c203bc88f3d1/./script.py", line 48, in add
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
  File "/home/uaf04777df9ce9aeb441c203bc88f3d1/miniforge3/envs/triton/lib/python3.11/site-packages/triton/runtime/jit.py", line 209, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uaf04777df9ce9aeb441c203bc88f3d1/miniforge3/envs/triton/lib/python3.11/site-packages/triton/runtime/jit.py", line 471, in run
    device = driver.active.get_current_device()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uaf04777df9ce9aeb441c203bc88f3d1/miniforge3/envs/triton/lib/python3.11/site-packages/triton/backends/intel/driver.py", line 405, in get_current_device
    return self.utils.get_current_device()
           ^^^^^^^^^^
  File "/home/uaf04777df9ce9aeb441c203bc88f3d1/miniforge3/envs/triton/lib/python3.11/site-packages/triton/backends/intel/driver.py", line 399, in __getattr__
    self.utils = XPUUtils()
                 ^^^^^^^^^^
  File "/home/uaf04777df9ce9aeb441c203bc88f3d1/miniforge3/envs/triton/lib/python3.11/site-packages/triton/backends/intel/driver.py", line 57, in __init__
    self.context = mod.init_context(self.get_sycl_queue())
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: an integer is required
```

(NB: the `TypeError` is raised by `mod.init_context`; the `self.get_sycl_queue()` call itself works and returns a capsule object wrapping a SYCL queue.)
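The traceback shows that the Intel driver builds its `XPUUtils` helper lazily: `self.utils` is assigned inside `__getattr__` the first time it is accessed. That is why the `init_context` failure surfaces at the first kernel launch rather than at `import triton`. A minimal sketch of that lazy-initialization pattern follows; `LazyDriver` and `FakeUtils` are hypothetical stand-ins that only mimic the shape visible in the traceback, not the actual `driver.py` code:

```python
class LazyDriver:
    """Hypothetical stand-in for a driver that creates its utils object on first use."""

    def __init__(self, utils_factory):
        self._utils_factory = utils_factory
        self.init_calls = 0

    def __getattr__(self, name):
        # Python calls __getattr__ only when normal attribute lookup fails,
        # i.e. before `utils` has been assigned on the instance.
        if name == "utils":
            self.init_calls += 1
            # If the factory raises (as XPUUtils.__init__ did), `utils` stays unset
            # and the exception propagates to whoever first touched `self.utils`.
            self.utils = self._utils_factory()
            return self.utils
        raise AttributeError(name)

    def get_current_device(self):
        return self.utils.get_current_device()


class FakeUtils:
    """Hypothetical utils object standing in for XPUUtils."""

    def get_current_device(self):
        return 0


if __name__ == "__main__":
    driver = LazyDriver(FakeUtils)
    print(driver.init_calls)            # nothing built yet
    print(driver.get_current_device())  # first access builds FakeUtils
    print(driver.init_calls)            # built exactly once
```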

Here are details about my environment:

```
Collecting environment information...
PyTorch version: 2.1.0.post2+cxx11.abi
PyTorch CXX11 ABI: Yes
IPEX version: 2.1.30+xpu
IPEX commit: 474a6b3cb
Build type: Release
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: N/A
IGC version: 2024.1.0 (2024.1.0.20240308)
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-106-generic-x86_64-with-glibc2.35
Is XPU available: True
DPCPP runtime version: 2024.1
MKL version: 2024.1
GPU models and configuration:
[0] _DeviceProperties(name='Intel(R) Data Center GPU Max 1100', platform_name='Intel(R) Level-Zero', dev_type='gpu', driver_version='1.3.28202', has_fp64=1, total_memory=49152MB, max_compute_units=448, gpu_eu_count=448)
Intel OpenCL ICD version: 23.52.28202.51-821~22.04
Level Zero version: 1.3.28202.51-821~22.04
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 224
On-line CPU(s) list: 0-223
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8480+
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 56
Socket(s): 2
Stepping: 8
CPU max MHz: 3800.0000
CPU min MHz: 800.0000
BogoMIPS: 4000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 5.3 MiB (112 instances)
L1i cache: 3.5 MiB (112 instances)
L2 cache: 224 MiB (112 instances)
L3 cache: 210 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-55,112-167
NUMA node1 CPU(s): 56-111,168-223
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.1.30+xpu
[pip3] numpy==1.26.4
[pip3] torch==2.1.0.post2+cxx11.abi
[conda] intel-extension-for-pytorch 2.1.30 py311_xpu_0 intel
[conda] mkl-include 2024.1.0 intel_691 intel
[conda] mkl-static 2024.1.0 intel_691 intel
[conda] numpy 1.26.4 py311h64a7726_0 conda-forge
[conda] pytorch 2.1.0 py311_xpu_3 intel
```
fcharras commented 4 months ago

Seems related to https://github.com/intel/intel-xpu-backend-for-triton/issues/973

pbchekin commented 4 months ago

Currently, our build of Triton is sensitive to the versions of PyTorch and IPEX. Try installing them from the pre-built wheels attached as artifacts to https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/9071023260.
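Because the Triton build is sensitive to the PyTorch and IPEX versions, it can help to record exactly which distributions are installed before and after swapping in the matching wheels. A minimal sketch using the standard-library `importlib.metadata`; the distribution names are assumptions (IPEX as `intel-extension-for-pytorch`, as in the pip listing above, and the Intel Triton wheel assumed to install as `triton`), and `installed_versions` is a hypothetical helper:

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed distribution names; adjust if the wheels install under different names.
PACKAGES = ("torch", "intel-extension-for-pytorch", "triton")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if it is missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

Comparing this output with the versions the pre-built artifacts were built against is a quick way to spot a mismatched PyTorch or IPEX install.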

fcharras commented 4 months ago

The nightlies are awesome! Thank you for the hint; I will try them and report back with the results.

fcharras commented 4 months ago

Yes, it's working with the latest nightly.