triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Exporting triton to PTX #2166

Open srush opened 1 year ago

srush commented 1 year ago

Hi,

I am trying to write a Triton kernel that I can then load in as PTX. Everything works fine if I just dump the PTX to a file and load it in:

a = list(kernel.cache[0].values())[0]

As long as my number of threads is 128 and my shared memory is 512:

func<<<grid_size, 128, 512, stream>>>

However, if I run the Python code with a bigger BLOCK_SIZE, I cannot seem to run it from PTX:

func<<<grid_size, 1024, 512, stream>>> // fails invalid input

I looked at the PTX code and I see this line, which I guess explains the issue:

.maxntid 128, 1, 1

But then if I try to run with a block of 128, I get:

func<<<grid_size, 128, 512, stream>>> // fails invalid access

I guess this is because the shared memory is wrong? But .shared is still 512 in the code cache.

Any tips for exporting PTX with bigger blocks?
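For reference, on the C++ side I load and launch the dumped PTX roughly like this (a minimal sketch using the CUDA driver API; error checking, context setup, and buffer allocation are omitted, and the helper name is just illustrative — the kernel name is the entry point in the generated PTX below):

// Sketch: load the dumped PTX and launch it with the CUDA driver API.
#include <cuda.h>
#include <fstream>
#include <sstream>
#include <string>

void launch_gptq(CUdeviceptr qweight, CUdeviceptr qscale, CUdeviceptr qzeros,
                 CUdeviceptr x, CUdeviceptr y, int in_features,
                 unsigned grid_size, CUstream stream) {
    // Read the PTX that was dumped from Python.
    std::ifstream f("gptq.ptx");
    std::stringstream ss;
    ss << f.rdbuf();
    std::string ptx = ss.str();

    CUmodule mod;
    CUfunction func;
    cuModuleLoadData(&mod, ptx.c_str());
    cuModuleGetFunction(&func, mod, "gptq_kernel_0d1d2d3d4d5d");

    // Argument layout mirrors the Triton signature: 5 pointers + 1 int.
    void *args[] = {&qweight, &qscale, &qzeros, &x, &y, &in_features};

    // 128 threads per block and 512 bytes of dynamic shared memory,
    // matching what Triton compiled for.
    cuLaunchKernel(func,
                   grid_size, 1, 1,   // grid
                   128, 1, 1,         // block
                   512,               // dynamic shared memory (bytes)
                   stream, args, nullptr);
}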

Here's my code if it helps:

import torch

import triton
import triton.language as tl

@triton.jit
def gptq_kernel(
    qweight_ptr,  # *Pointer* to qweight int input matrix.
    qscale_ptr,  # *Pointer* to qscale f16 input matrix.
    qzeros_ptr,  # *Pointer* to qzeros int input matrix.
    x_ptr, # *Pointer* to x input vector.
    y_ptr, # *Pointer* to y output vector.
    IN,
    BLOCK_SIZE: tl.constexpr):

    pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
    block_start = pid * BLOCK_SIZE
    block = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    group_size: tl.constexpr = 128
    BITS = 4
    mask = 2**4-1

    in_elem = block % IN
    out_elem = block // IN

    # x is just group size
    x = tl.load(x_ptr + in_elem)

    # scale is taken from the current group
    group = block // group_size
    scale = tl.load(qscale_ptr + group)

    # zeros are repeated group_size times
    zeros = tl.load(qzeros_ptr + (in_elem * out_elem // 8) // 128)
    zero_shift = (out_elem % 8) * BITS
    zeros = ((zeros >> zero_shift) & mask) + 1

    # Compute
    offset = (block % 8) * BITS
    splat = tl.load(qweight_ptr + (block // 8))
    vals = (splat >> offset) & mask
    out = scale * (vals - zeros) * x
    tl.store(y_ptr + pid, tl.sum(out, axis=0))

IN = 4096
OUT = 4096
GROUPSIZE = 128

q_weights = torch.tensor([[1985229328] * (IN // 8)] * OUT).int().cuda()
q_scale = torch.tensor([[10.] * (IN // GROUPSIZE)] * OUT).float().cuda()
q_zeros = torch.tensor([[1985229328] * (IN // GROUPSIZE)] * (OUT // 8)).int().cuda()
x = torch.tensor([[10.] * IN]).cuda()
q_out = torch.zeros(OUT).float().cuda()
print(q_weights.shape)
print(q_scale.shape)
print(q_zeros.shape)

n_elements = OUT * IN
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
gptq_kernel[grid](q_weights, q_scale, q_zeros, x, q_out, IN, BLOCK_SIZE=1024)
print(q_out)

print(dir(gptq_kernel.cache))
with open("gptq.ptx", "w") as a:
    print(list(gptq_kernel.cache[0].values())[0].asm['ptx'], file=a)
//                                                                                                                                                                                                                                                                                                                             
// Generated by LLVM NVPTX Back-End                                                                                                                                                                                                                                                                                            
//                                                                                                                                                                                                                                                                                                                             

.version 7.5                                                                                                                                                                                                                                                                                                                   
.target sm_75                                                                                                                                                                                                                                                                                                                  
.address_size 64                                                                                                                                                                                                                                                                                                               

    // .globl   gptq_kernel_0d1d2d3d4d5d                                                                                                                                                                                                                                                                                       
.extern .shared .align 1 .b8 global_smem[];                                                                                                                                                                                                                                                                                    

.visible .entry gptq_kernel_0d1d2d3d4d5d(                                                                                                                                                                                                                                                                                      
    .param .u64 gptq_kernel_0d1d2d3d4d5d_param_0,                                                                                                                                                                                                                                                                              
    .param .u64 gptq_kernel_0d1d2d3d4d5d_param_1,                                                                                                                                                                                                                                                                              
    .param .u64 gptq_kernel_0d1d2d3d4d5d_param_2,                                                                                                                                                                                                                                                                              
    .param .u64 gptq_kernel_0d1d2d3d4d5d_param_3,                                                                                                                                                                                                                                                                              
    .param .u64 gptq_kernel_0d1d2d3d4d5d_param_4,                                                                                                                                                                                                                                                                              
    .param .u32 gptq_kernel_0d1d2d3d4d5d_param_5                                                                                                                                                                                                                                                                               
)                                                                                                                                                                                                                                                                                                                              
.maxntid 128, 1, 1                                                                                                                                                                                                                                                                                                             
{                                                                                                                                                                                                                                                                                                                              
    .reg .pred  %p<32>;                                                                                                                                                                                                                                                                                                        
    .reg .b32   %r<302>;                                                                                                                                                                                                                                                                                                       
    .reg .f32   %f<56>;                                                                                                                                                                                                                                                                                                        
    .reg .b64   %rd<60>;                                                                                                                                                                                                                                                                                                       

    ld.param.u64    %rd28, [gptq_kernel_0d1d2d3d4d5d_param_0];                                                                                                                                                                                                                                                                 
    ld.param.u64    %rd29, [gptq_kernel_0d1d2d3d4d5d_param_1];                                                                                                                                                                                                                                                                 
    mov.u32     %r52, %tid.x;                                                                                                                                                                                                                                                                                                  
    and.b32     %r53, %r52, 31;                                                                                                                                                                                                                                                                                                
    ld.param.u64    %rd30, [gptq_kernel_0d1d2d3d4d5d_param_2];                                                                                                                                                                                                                                                                 
    shr.u32     %r54, %r52, 3;                                                                                                                                                                                                                                                                                                 
    and.b32     %r55, %r54, 536870908;                                                                                                                                                                                                                                                                                         
    ld.param.u64    %rd31, [gptq_kernel_0d1d2d3d4d5d_param_3];                                                                                                                                                                                                                                                                 
    shl.b32     %r56, %r52, 2;                                                                                                                                                                                                                                                                                                 
    ld.param.u64    %rd32, [gptq_kernel_0d1d2d3d4d5d_param_4];                                                                                                                                                                                                                                                                 
    and.b32     %r57, %r56, 1020;                                                                                                                                                                                                                                                                                              
    ld.param.u32    %r58, [gptq_kernel_0d1d2d3d4d5d_param_5];                                                                                                                                                                                                                                                                  
    mov.u32     %r59, %ctaid.x;                                                                                                                                                                                                                                                                                                
    shl.b32     %r60, %r59, 10;                                                                                                                                                                                                                                                                                                
    or.b32      %r61, %r57, %r60;                                                                                                                                                                                                                                                                                              
    or.b32      %r62, %r61, 1;                                                                                                                                                                                                                                                                                                 
    or.b32      %r63, %r61, 3;                                                                                                                                                                                                                                                                                                 
    or.b32      %r64, %r61, 2;                                                                                                                                                                                                                                                                                                 
    add.s32     %r65, %r61, 513;                                                                                                                                                                                                                                                                                               
    add.s32     %r66, %r61, 512;                                                                                                                                                                                                                                                                                               
    add.s32     %r67, %r61, 515;                                                                                                                                                                                                                                                                                               
    add.s32     %r68, %r61, 514;                                                                                                                                                                                                                                                                                               
    div.s32     %r69, %r61, %r58;                                                                                                                                                                                                                                                                                              
    mul.lo.s32  %r70, %r69, %r58;                                                                                                                                                                                                                                                                                              
    sub.s32     %r71, %r61, %r70;                                                                                                                                                                                                                                                                                              
    div.s32     %r72, %r62, %r58;                                                                                                                                                                                                                                                                                              
    mul.lo.s32  %r73, %r72, %r58;                                                                                                                                                                                                                                                                                              
    sub.s32     %r74, %r62, %r73;                                                                                                                                                                                                                                                                                              
    div.s32     %r75, %r64, %r58;                                                                                                                                                                                                                                                                                              
    mul.lo.s32  %r76, %r75, %r58;                                                                                                                                                                                                                                                                                              
    sub.s32     %r77, %r64, %r76;       
Jokeren commented 1 year ago

I might have missed some details, but you could try to rerun the Triton kernel with the grid size/block size you would like and then export the PTX.

srush commented 1 year ago

Yes, I ran the Triton kernel with BLOCK_SIZE=1024 (as shown above), but the asm["ptx"] that it produces still has .maxntid 128, 1, 1. Am I doing something wrong? Should BLOCK_SIZE in Triton correspond to the CUDA block size directly? (I don't know how Triton translates the code to CUDA.)

Jokeren commented 1 year ago

Should BLOCK_SIZE in triton correspond to CUDA blocksize directly?

No

Jokeren commented 1 year ago

num_warps x warp_size = cta_size; on NVIDIA GPUs, warp_size = 32. With the default num_warps = 4, that gives a CTA of 4 * 32 = 128 threads, which is the .maxntid 128, 1, 1 in your PTX.

srush commented 1 year ago

Hmm, I'm confused. So if I want to run the PTX output from Triton that was originally BLOCK_SIZE 1024, should I run CUDA blocks of 1024 / num_warps? What do I set shared memory to?

In general, where should I look to see how Triton calls the kernels internally?

Jokeren commented 1 year ago

BLOCK_SIZE is not equivalent to the number of threads in a CUDA block; it only specifies how many elements are handled by each block. Here, with BLOCK_SIZE=1024 and num_warps=4 (128 threads), each thread ends up handling 8 elements.

https://github.com/openai/triton/blob/5d47054a05bdd48624d69fdb94e29023de233e65/python/triton/compiler/compiler.py#L695

srush commented 1 year ago

Great, so that answers half my question. It sounds like the CTA size here for CUDA should be 128, which corresponds to 32 threads per warp * 4 warps per CTA.

But I still don't know how to call this as a CUDA function. What values should I pass in for grid, block, and shared memory?

(TTGIR)

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func public @gptq_kernel_0d1d2d3d4d5d(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
    %cst = arith.constant dense<8> : tensor<1024xi32, #blocked>
srush commented 1 year ago

Oh, I think I get it now.

If I have a Triton BLOCK_SIZE of 1024, I should still use <<<TOTAL / BLOCK_SIZE, 128, shared_mem>>> in CUDA and it will correspond to the original Triton function. 128 = 32 * num_warps.
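Concretely, for the kernel above that works out to the following (a sketch; func, args, and stream come from a driver-API setup like the one in the first post, and the 512-byte shared-memory size is the value observed in the compiled kernel):

// Launch configuration for BLOCK_SIZE=1024 with Triton's default num_warps=4
// (matches "triton_gpu.num-warps = 4" in the TTGIR above).
const int BLOCK_SIZE = 1024;                 // elements handled per Triton program
const int num_warps  = 4;
const int n_elements = 4096 * 4096;          // OUT * IN
unsigned grid  = (n_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;   // same as triton.cdiv
unsigned block = 32 * num_warps;             // 128 threads per CTA
unsigned shmem = 512;                        // bytes of dynamic shared memory

cuLaunchKernel(func, grid, 1, 1, block, 1, 1, shmem, stream, args, nullptr);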

srush commented 1 year ago

This worked for me for simple functions, but now it seems to be failing for more complex functions. I'm getting weird bugs with alignment and outputs.

Besides grid and shared memory, are there any other settings I need to export to call the PTX code outside of Python? I noticed, for instance, that the number of args in the PTX sometimes differs from the original function (it seems to stop at 12). Is there a better way to use Triton outside of Python?

Jokeren commented 1 year ago

You can check the AOT module.

jeromeku commented 9 months ago

@srush

Wrote some tools for exactly this purpose. See here and here.

The reason the function signature in the compiled PTX/cubin differs from the original is that some of the non-tensor args are treated as compile-time constants (even those not marked as tl.constexpr) and are no longer part of the final signature; for example, integer arguments that happen to equal 1 can be specialized away. The metadata generated during the compilation pipeline, specifically the specializations, should shed more light.

The PR per above handles these cases automatically.

Happy to adapt these tools for your use case.

wbrickner commented 4 months ago

Hi, this trick for exporting PTX of a JIT'd kernel no longer works; the organization of cache etc. seems to have changed since summer 2023.

What is now the correct method for accessing PTX / other compiled kernel format?

hubertlu-tw commented 1 day ago

Hi, this trick for exporting PTX of a JIT'd kernel no longer works; the organization of cache etc. seems to have changed since summer 2023.

What is now the correct method for accessing PTX / other compiled kernel format?

@wbrickner I was actually able to export the generated PTX using the trick mentioned above on Google Colab with a T4 GPU, PyTorch 2.4.1+cu121, and Triton 3.1.0 (Oct 14, 2024).

# Access the cache of compiled kernels
kernel_cache = your_triton_kernel.cache     # You need to launch the kernel first so that the PTX is generated in cache.

# Get the first cached compilation (or iterate over all)
cache_entry = kernel_cache[0]

compiled_kernel = list(cache_entry.values())[0]

# Access the 'asm' dictionary
asm_dict = compiled_kernel.asm

# List all keys in the 'asm' dictionary
all_keys = asm_dict.keys()
print("All keys in asm dictionary:", all_keys) # All keys in asm dictionary: dict_keys(['ttir', 'ttgir', 'llir', 'ptx', 'cubin'])

with open("your_triton_kernel.ptx", "w") as a:
    print(asm_dict['ptx'], file=a)