intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

Query device architecture and feature flags in Triton #1900

Open etiotto opened 3 months ago

etiotto commented 3 months ago

Currently, PyTorch passes Triton the HW flags describing which HW features are available on the target GPU (e.g., DPAS, 2D block operations, etc.). This design requires every potential front end to query the HW flags and pass them to the Triton compiler, and therefore does not scale well.

A better solution is to query the HW capabilities directly in the Triton compiler.
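
A minimal sketch of the two flows, for illustration only. The option/flag names and property attributes below are hypothetical assumptions, not the actual Triton or PyTorch APIs, and it assumes a PyTorch build with XPU support:

```python
import torch

# Current flow: the front end queries the device and forwards
# feature flags to the Triton compiler on every compile call.
props = torch.xpu.get_device_properties(0)
target_options = {
    # Hypothetical flag/attribute names; every front end must know
    # and repeat this mapping itself.
    "has_dpas": getattr(props, "has_subgroup_matrix_multiply_accumulate", False),
    "has_2d_block_io": getattr(props, "has_2d_block_io", False),
}
# triton.compile(src, options=target_options)

# Proposed flow: the Triton compiler's XPU driver queries the SYCL
# runtime itself, so front ends pass nothing device-specific:
# triton.compile(src)
```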

etiotto commented 3 months ago

@chengjunlu you self-assigned this ticket. We need to discuss the design of this a bit.

alexbaden commented 3 months ago

The problem with directly accessing the SYCL API in the Triton compiler is that we have historically been blocked by a PyTorch upstream "in bad fork" runtime error: https://github.com/pytorch/pytorch/blob/main/torch/xpu/__init__.py#L113. Torch Inductor runs multiple triton.compile calls inside a Python subprocess. Accessing the global Torch XPU context via any of the torch.xpu calls goes through the lazy-init routine, which checks whether the caller is inside a forked subprocess and fails if it is (a minimal repro sketch follows the list below). It might be possible to pass the raw SYCL device pointer to triton.compile, but I see two problems:

  1. We would be relying entirely on the SYCL runtime to avoid race conditions, whereas in the current approach PyTorch leverages the GIL to ensure thread-safe access to the runtime.
  2. Other backends (AMD, NVIDIA) do not support anything like this: target properties are set by Torch Inductor for the device and passed to triton.compile. Changing our behavior could make upstreaming difficult, as it would likely diverge from the upstream APIs (unless we do some kind of global init, but that would make the race condition in (1) worse).
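
For reference, a minimal sketch of how the "in bad fork" error can be triggered (assuming a PyTorch build with XPU support; this mirrors the failure mode Inductor's compile subprocesses run into):

```python
import multiprocessing as mp

import torch

def query_device(_):
    # Any torch.xpu call routes through the lazy-init path, which
    # detects that the interpreter was forked after XPU was
    # initialized and raises the error at the linked line
    # ("Cannot re-initialize XPU in forked subprocess").
    return torch.xpu.get_device_properties(0)

if __name__ == "__main__":
    torch.xpu.init()  # parent process initializes the XPU context
    with mp.get_context("fork").Pool(1) as pool:
        pool.apply(query_device, (0,))  # raises RuntimeError in the child
```
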
chengjunlu commented 3 months ago

The feedback from Eikan about adding the SYCL device pointer to triton.compile:

He thinks it would be hard to upstream this kind of change to PyTorch, since it is not aligned with AMD and NV.

But he thinks Triton is free to accept the SYCL device pointer as an input to triton.compile for frameworks other than PyTorch.

whitneywhtsang commented 1 month ago

This is not enabled by default, so reopening.