vlad-penkin opened 3 months ago
It seems we can retrieve the native binary from a Level Zero module using zeModuleGetNativeBinary.
Yes, assuming they work, L0 has APIs we should be able to leverage.
The difference is in the way the compiler/driver works. For NVIDIA, they can call ptxas to assemble the PTX to a cubin and then pass the cubin to their runtime. For us, the SPIR-V is actually compiled to machine code during the driver stage. So I need to either lift the compilation from SPIR-V to native binary out of the driver and into the compiler, or find a way to get the paths to the driver without breaking Triton's layering.
I wanted to look into this to see if it could be related to #1721, but the numbers don't quite match, so I am not optimistic. Still, this could be a nice win for us, as compilation can take 100-300 ms, especially if there are register spills and we recompile.
I think this may be the solution:
We can retrieve the native binary from the Level Zero module using zeModuleGetNativeBinary. Here is an example: https://github.com/oneapi-src/oneDNN/blob/2e7b691217ff17497aebd7e565fa1701f8a42396/src/gpu/intel/sycl/utils.cpp#L211
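For reference, retrieval follows the usual Level Zero two-call convention: query the size first, then copy the bytes. Below is a minimal sketch of that pattern; the type declarations and the zeModuleGetNativeBinary stub stand in for <level_zero/ze_api.h> and the real driver, so only the calling pattern is meaningful, not the fake bytes:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Minimal stand-ins for <level_zero/ze_api.h>; the real header
// provides these types and the driver provides the entry point.
using ze_result_t = int;
constexpr ze_result_t ZE_RESULT_SUCCESS = 0;
using ze_module_handle_t = void*;

// Stub driver (for illustration only): first call (pBinary == nullptr)
// reports the size, second call copies the bytes, mirroring the
// documented contract of zeModuleGetNativeBinary.
static const uint8_t kFakeNativeBinary[] = {0x7f, 'Z', 'E', 0x01};
ze_result_t zeModuleGetNativeBinary(ze_module_handle_t /*hModule*/,
                                    size_t* pSize, uint8_t* pBinary) {
    if (pBinary == nullptr) {
        *pSize = sizeof(kFakeNativeBinary);
    } else {
        std::memcpy(pBinary, kFakeNativeBinary, *pSize);
    }
    return ZE_RESULT_SUCCESS;
}

// What a cache step could do after the module is JIT-built from
// SPIR-V: serialize the device ISA for later reuse.
std::vector<uint8_t> get_native_binary(ze_module_handle_t mod) {
    size_t size = 0;
    ze_result_t rc = zeModuleGetNativeBinary(mod, &size, nullptr);
    assert(rc == ZE_RESULT_SUCCESS);
    std::vector<uint8_t> binary(size);
    rc = zeModuleGetNativeBinary(mod, &size, binary.data());
    assert(rc == ZE_RESULT_SUCCESS);
    return binary;
}
```

The returned byte vector is what would be written to the persistent cache in place of the SPIR-V.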
Then, to reconstruct the L0 module at deployment time, create the Level Zero module with the ze_module_format_t set to ZE_MODULE_FORMAT_NATIVE in zeModuleCreate.
Here is the example: https://github.com/oneapi-src/oneDNN/blob/2e7b691217ff17497aebd7e565fa1701f8a42396/src/gpu/intel/sycl/l0/utils.cpp#L184
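To illustrate the reload path: the cached bytes go back through zeModuleCreate with the format set to ZE_MODULE_FORMAT_NATIVE, so the driver deserializes the ISA instead of recompiling SPIR-V. The sketch below stubs the driver and mirrors only a subset of the real ze_module_desc_t fields, so treat it as the calling pattern rather than working deployment code:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Minimal stand-ins for <level_zero/ze_api.h>.
using ze_result_t = int;
constexpr ze_result_t ZE_RESULT_SUCCESS = 0;
enum ze_module_format_t { ZE_MODULE_FORMAT_IL_SPIRV = 0,
                          ZE_MODULE_FORMAT_NATIVE = 1 };
using ze_context_handle_t = void*;
using ze_device_handle_t = void*;
using ze_module_handle_t = void*;

// Simplified mirror of ze_module_desc_t (the real struct also
// carries stype, pNext, and pConstants).
struct ze_module_desc_t {
    ze_module_format_t format;
    size_t inputSize;
    const uint8_t* pInputModule;
    const char* pBuildFlags;
};

// Stub: a real driver would deserialize the native ISA here rather
// than running the SPIR-V -> ISA compiler, which is the whole win.
ze_result_t zeModuleCreate(ze_context_handle_t, ze_device_handle_t,
                           const ze_module_desc_t* desc,
                           ze_module_handle_t* phModule,
                           void* /*phBuildLog*/) {
    assert(desc->format == ZE_MODULE_FORMAT_NATIVE);  // fast path
    *phModule = const_cast<uint8_t*>(desc->pInputModule);
    return ZE_RESULT_SUCCESS;
}

// Deployment-time reload: hand the cached bytes back with the
// format marked NATIVE.
ze_module_handle_t load_native_module(ze_context_handle_t ctx,
                                      ze_device_handle_t dev,
                                      const uint8_t* bytes, size_t size) {
    ze_module_desc_t desc{ZE_MODULE_FORMAT_NATIVE, size, bytes, nullptr};
    ze_module_handle_t mod = nullptr;
    ze_result_t rc = zeModuleCreate(ctx, dev, &desc, &mod, nullptr);
    assert(rc == ZE_RESULT_SUCCESS);
    return mod;
}
```

Note that a native binary is device-specific, so a real cache would also need to key on the target device, much like cubins are keyed on SM architecture.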
I have a prototype working. The Level Zero APIs are the easy part; the significant changes are to our compilation flow, to fit this into Triton's architecture. Fortunately, I think I can adjust the compilation flow while preserving the existing Triton layering. I will clean up my prototype and post it as a draft PR for review tomorrow.
Blocked by
There is a plan to enable AOT Inductor for Intel GPU in PyTorch 2.6. While working on the design, the PyTorch team realized that the Triton kernel is currently saved as SPIR-V (IR), while for CUDA it is saved as a cubin (device code binary), which will affect E2E performance:
The PyTorch team is asking whether Triton can save the compiled kernel as device binary code and load it with the L0 runtime.