etiotto opened 3 months ago
@chengjunlu you self assigned to this ticket. We need to discuss a bit the design of this.
The problem with directly accessing the SYCL API in the Triton compiler is we have historically been blocked by a PyTorch upstream "in bad fork" runtime error: https://github.com/pytorch/pytorch/blob/main/torch/xpu/__init__.py#L113
Torch Inductor runs multiple triton.compile calls inside a Python subprocess. Accessing the global Torch XPU context via any of the torch.xpu calls passes through the lazy init routine, which checks whether the caller is inside a subprocess, and fails.
It might be possible to pass the raw SYCL device pointer to triton.compile, but I see two problems:
triton.compile. Changing our behavior could make upstreaming difficult, as it will likely result in divergence from the upstream APIs (unless we do some kind of global init, but that would make the race condition in (1) worse). The feedback from Eikan about adding the SYCL device pointer to torch.compile:
He thinks it is hard to upstream this kind of change to PyTorch since it is not aligned with AMD and NV. But he thinks Triton is free to use the SYCL device pointer as input for torch.compile for frameworks other than PyTorch.
It is not enabled by default; reopening.
Currently PyTorch passes to Triton the HW flags describing which HW features are available in the target GPU (e.g. DPAS, 2D block operations, etc.). This design requires every potential front end to query the HW flags and pass them to the Triton compiler, and therefore doesn't scale well.
A better solution is to query the HW capabilities in the Triton compiler itself.
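To make the scaling problem concrete, here is a hedged Python sketch contrasting the two designs. The Device class, flag names, and function signatures are all invented for illustration; they are not the actual Triton or PyTorch APIs.

```python
# Illustrative sketch only: Device, the flag names, and the compile
# functions below are invented, not real Triton/PyTorch APIs.

class Device:
    """Stand-in for a GPU handle the compiler could query directly."""
    def __init__(self, features):
        self._features = set(features)

    def has(self, feature):
        return feature in self._features

def backend_compile(src, flags):
    # The compiler consumes capability flags when choosing a lowering.
    target = "dpas" if flags["has_dpas"] else "fma"
    return {"src": src, "target": target}

def frontend_compile(src, device):
    # Current design: each front end (Torch Inductor, others) must query
    # the HW flags and forward them -- this logic is duplicated per front end.
    flags = {"has_dpas": device.has("dpas"),
             "has_2d_block_ops": device.has("2d_block")}
    return backend_compile(src, flags)

def backend_compile_self_query(src, device):
    # Proposed design: the compiler queries the device itself, so front
    # ends only hand over the device and the flag logic lives in one place.
    flags = {"has_dpas": device.has("dpas"),
             "has_2d_block_ops": device.has("2d_block")}
    return backend_compile(src, flags)

dev = Device({"dpas", "2d_block"})
a = frontend_compile("kernel", dev)
b = backend_compile_self_query("kernel", dev)
```

Both paths produce the same artifact; the difference is where the capability query lives, and the second variant removes the per-front-end duplication.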