srush opened this issue 1 year ago
I might have missed some details, but you could try to rerun a triton kernel with the grid size/block size you would like and then export the ptx.
Yes, I ran the Triton kernel with BLOCK_SIZE=1024 (as shown above), but the asm["ptx"] it produces still has .maxntid 128, 1, 1. Am I doing something wrong? Should BLOCK_SIZE in Triton correspond to the CUDA block size directly? (I don't know how Triton translates the code to CUDA.)
Should BLOCK_SIZE in triton correspond to CUDA blocksize directly?
No.
num_warps x warp_size = cta_size; on NVIDIA GPUs, warp_size = 32.
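To make that arithmetic concrete, here is a minimal plain-Python sketch (the function name is illustrative, not a Triton API):

```python
WARP_SIZE = 32  # fixed warp size on NVIDIA GPUs

def cta_size(num_warps: int) -> int:
    """CUDA threads per block (CTA) for a Triton kernel compiled with num_warps."""
    return num_warps * WARP_SIZE

# Triton's default num_warps=4 yields a 128-thread CTA, which is why the
# exported PTX carries .maxntid 128, 1, 1 regardless of BLOCK_SIZE.
print(cta_size(4))  # 128
```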
Hmm, I'm confused. So if I want to run the PTX output from Triton that was originally BLOCK_SIZE 1024, should I run CUDA blocks of 1024 / num_warps? What do I set shared memory to?
In general, where should I look to see how Triton calls the kernels internally?
BLOCK_SIZE is not equivalent to the number of threads in a CUDA block; it only specifies how many elements each block will handle.
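In other words, elements per block and threads per block are independent knobs. A sketch with the numbers from this thread (BLOCK_SIZE=1024, the default num_warps=4):

```python
WARP_SIZE = 32     # NVIDIA warp size
BLOCK_SIZE = 1024  # elements processed by one Triton program (one CTA)
NUM_WARPS = 4      # Triton launch option (default is 4)

threads_per_cta = NUM_WARPS * WARP_SIZE               # 128 CUDA threads per block
elements_per_thread = BLOCK_SIZE // threads_per_cta   # each thread covers 8 elements

print(threads_per_cta, elements_per_thread)  # 128 8
```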
Great, so that answers half my question. It sounds like the CUDA CTA size here should be 128, which corresponds to 32 threads/warp x 4 warps/CTA.
But I still don't know how to call this as a CUDA function. What values should I pass for grid, block, and shared memory?
(TTGIR)
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func public @gptq_kernel_0d1d2d3d4d5d(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<8> : tensor<1024xi32, #blocked>
Oh, I think I get it now.
If I have a Triton BLOCK_SIZE of 1024, I should still use <<<TOTAL / BLOCK_SIZE, 128, shared_mem>>> in CUDA, and it will correspond to the original Triton function. 128 = 32 * num_warps.
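That derivation can be written out as a small sketch (the helper name is made up; the shared-memory size has to come from the compiled kernel, e.g. its metadata, rather than being derivable from BLOCK_SIZE):

```python
import math

def cuda_launch_params(total_elements: int, triton_block_size: int,
                       num_warps: int, shared_bytes: int):
    """Derive <<<grid, block, shared>>> for PTX exported from a Triton kernel."""
    grid = math.ceil(total_elements / triton_block_size)  # one CTA per Triton program
    block = num_warps * 32                                # CUDA threads per CTA
    return grid, block, shared_bytes

# e.g. 1M elements, Triton BLOCK_SIZE=1024, num_warps=4, 512 B of shared memory
print(cuda_launch_params(1 << 20, 1024, 4, 512))  # (1024, 128, 512)
```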
This worked for me for simple functions, but it now seems to fail for more complex ones: I'm getting strange bugs with alignment and outputs.
Besides grid and shared memory, are there any other settings I need to export to call the PTX code outside of Python? I noticed, for instance, that the number of args in the PTX sometimes differs from the original function (it seems to stop at 12). Is there a better way to use Triton outside of Python?
You can check the AOT module.
@srush
Wrote some tools for exactly this purpose. See here and here.
The reason the function signature in the compiled PTX/cubin differs from the original is that some of the non-tensor args are treated as compile-time constants (even those not marked as tl.constexpr) and are no longer part of the final signature. The metadata generated during the compilation pipeline, specifically the specializations, should shed more light.
The PR mentioned above handles these cases automatically.
Happy to adapt these tools for your use case.
Hi, this trick for exporting PTX of a JIT'd kernel no longer works, the organization of cache etc seems to have changed since summer 2023.
What is now the correct method for accessing PTX / other compiled kernel format?
@wbrickner I was actually able to export the generated PTX using the trick mentioned above, on Google Colab with a T4 GPU, PyTorch version 2.4.1+cu121, and Triton version 3.1.0 (Oct 14, 2024).
# Access the cache of compiled kernels.
# Launch the kernel at least once first so the PTX is generated and cached.
kernel_cache = your_triton_kernel.cache

# The cache is keyed by device index; grab the first device's entries.
cache_entry = kernel_cache[0]
compiled_kernel = list(cache_entry.values())[0]

# Access the 'asm' dictionary.
asm_dict = compiled_kernel.asm

# List all keys in the 'asm' dictionary.
all_keys = asm_dict.keys()
print("All keys in asm dictionary:", all_keys)
# All keys in asm dictionary: dict_keys(['ttir', 'ttgir', 'llir', 'ptx', 'cubin'])

# Dump the PTX to a file.
with open("your_triton_kernel.ptx", "w") as f:
    f.write(asm_dict["ptx"])
Hi,
I am trying to write a triton kernel that I can then load in as PTX. Everything works fine if I just dump the PTX to file and load it in.
a = list(kernel.cache[0].values())[0]
As long as my number of threads is 128 and shared memory is 512, this works:
func<<<grid_size, 128, 512, stream>>>
However, if I run the Python code with a bigger block_size, I cannot seem to run it from PTX:
func<<<grid_size, 1024, 512, stream>>> // fails: invalid input
I looked at the PTX code and I see this line which I guess explains the issue.
.maxntid 128, 1, 1
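As a side note, one way to sanity-check this bound before launching is to parse it out of the exported PTX text (a hedged sketch using a simple regex; the PTX snippet below is made up for illustration):

```python
import re

def maxntid(ptx: str):
    """Return the X dimension of the .maxntid directive, i.e. the maximum
    threads per block the PTX was compiled for, or None if absent."""
    m = re.search(r"\.maxntid\s+(\d+)\s*,\s*(\d+)\s*,\s*(\d+)", ptx)
    return int(m.group(1)) if m else None

ptx_snippet = ".visible .entry my_kernel()\n.maxntid 128, 1, 1\n{ ret; }"
print(maxntid(ptx_snippet))  # 128
```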
But then if I try to run with a 128 block, I get:
func<<<grid_size, 128, 512, stream>>> // fails: invalid access
I guess this is because the shared memory is wrong? But .shared is still 512 in the code cache.
Any tips for exporting PTX with bigger blocks?
Here's my code if it helps: