triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Torch trace: replace_pattern with Triton kernel errors (noob issue) #2684

Open cm2435 opened 11 months ago

cm2435 commented 11 months ago

GitHub Issue

Hello! Sorry for the slightly cross-posted issue, but this seems like the most likely place to post it and get some help.

I'm trying a basic "hello world" of substituting operator patterns in a symbolic trace of a PyTorch module, to get to grips with properly integrating the Triton kernels I write into native PyTorch code.

To replicate this, I took the vector-addition code from Tutorial 1:

import torch 
import triton 
import triton.language as tl 

# Autotune over two block sizes; re-tune whenever n_elements changes
@triton.autotune(configs=[
    triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
    triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
], key=['n_elements'])
@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance handles one contiguous block of BLOCK_SIZE elements
    block_pid = tl.program_id(axis=0)
    block_start = block_pid * BLOCK_SIZE
    block_offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the last block, which may run past the end of the inputs
    mask = block_offsets < n_elements

    x = tl.load(x_ptr + block_offsets, mask=mask)
    y = tl.load(y_ptr + block_offsets, mask=mask)

    output = x + y

    tl.store(output_ptr + block_offsets, output, mask=mask)


def triton_add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    n_elements = output.numel()

    torch._assert(x.is_cuda and y.is_cuda and output.is_cuda, message="All devices must be on cuda")

    # One program per BLOCK_SIZE chunk; BLOCK_SIZE comes from the chosen autotune config
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    add_kernel[grid](x, y, output, n_elements)

    return output
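The launch above is easier to reason about with the GPU taken out of the picture. Below is a hypothetical pure-Python model (no Triton involved; launch_plan and simulated_add are illustrative names) of how grid, BLOCK_SIZE and mask split n_elements across program instances. It assumes triton.cdiv is plain ceiling division, which is what it computes for positive integers.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division for positive ints (what triton.cdiv computes)
    return (a + b - 1) // b


def launch_plan(n_elements: int, block_size: int):
    """Hypothetical model of add_kernel's launch: returns the grid and, per
    program id, the offsets that survive the mask (offsets < n_elements)."""
    grid = (cdiv(n_elements, block_size),)
    plan = []
    for pid in range(grid[0]):
        block_start = pid * block_size
        block_offsets = range(block_start, block_start + block_size)
        plan.append([off for off in block_offsets if off < n_elements])
    return grid, plan


def simulated_add(x: list, y: list, block_size: int) -> list:
    """Element-wise x + y computed block by block, mirroring add_kernel."""
    output = [0.0] * len(x)
    _, plan = launch_plan(len(x), block_size)
    for offsets in plan:      # one iteration per program instance
        for off in offsets:   # masked load, add, masked store
            output[off] = x[off] + y[off]
    return output


grid, plan = launch_plan(1000, 128)
print(grid, [len(block) for block in plan])
# (8,) [128, 128, 128, 128, 128, 128, 128, 104]
```

With BLOCK_SIZE=128, 1000 elements need 8 programs, and only the last one is partially masked; with the 1024 config a single masked program covers everything.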

And wrote the most basic PyTorch module I could come up with:

class VectorAdd(torch.nn.Module): 
    '''
    Most basic example of a torch module I can think of. Just does vector addition
    '''
    def __init__(self): 
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor): 
        addition = x + y
        return addition + 1

Following the PyTorch FX documentation on pattern substitution with replace_pattern, the module traces fine (as it should):

from torch.fx import symbolic_trace

symbolic_traced : torch.fx.GraphModule = symbolic_trace(VectorAdd())
print("Code:")
print(symbolic_traced.code)
print("Tensor graph:")
# print_tabular() prints the table itself and returns None; it needs the tabulate package
symbolic_traced.graph.print_tabular()
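To see exactly what replace_pattern has to match, the traced graph can also be walked node by node without the optional tabulate dependency. A minimal sketch, assuming a recent PyTorch with torch.fx; the module is re-declared so the snippet runs on its own (CPU only):

```python
import operator

import torch
from torch.fx import symbolic_trace


class VectorAdd(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        addition = x + y
        return addition + 1


gm = symbolic_trace(VectorAdd())

# Every node carries an opcode, a name, a target and its arguments.
for node in gm.graph.nodes:
    print(node.op, node.name, node.target, node.args)

ops = [node.op for node in gm.graph.nodes]
add_nodes = [
    n for n in gm.graph.nodes
    if n.op == "call_function" and n.target is operator.add
]
```

Both additions show up as call_function nodes targeting operator.add; only the first one has two tensor inputs, while the second adds the literal 1.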

But the code

from torch.fx import replace_pattern

def sub_addition(x: torch.Tensor, y: torch.Tensor):
    return x + y

replace_pattern(symbolic_traced, sub_addition, triton_add)

fails with the error

---------------------------------------------------------------------------
TraceError                                Traceback (most recent call last)
...
TraceError: symbolically traced variables cannot be used as inputs to control flow
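A minimal way to reproduce this error in isolation, sketched under the assumption of a recent PyTorch (device_check_only, plain_add and traces_ok are hypothetical helper names). replace_pattern symbolically traces the replacement function, and during tracing x and y are torch.fx Proxy objects. Any Python construct that needs a concrete bool from them raises TraceError. The device check in triton_add is one such construct, because "and" has to evaluate the truth value of x.is_cuda; the Triton launch itself presumably needs concrete argument values too.

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.proxy import TraceError


def device_check_only(x, y):
    # Same device check as triton_add. Under tracing, x and y are Proxy
    # objects, so `and` asks Python for the truth value of x.is_cuda,
    # which torch.fx cannot know without real data.
    torch._assert(x.is_cuda and y.is_cuda, "All devices must be on cuda")
    return x + y


def plain_add(x, y):
    # Only tensor ops, no Python control flow: traces cleanly.
    return x + y


def traces_ok(fn) -> bool:
    try:
        symbolic_trace(fn)
    except TraceError as err:
        print(f"{fn.__name__}: TraceError: {err}")
        return False
    print(f"{fn.__name__}: traced fine")
    return True


traces_ok(device_check_only)
traces_ok(plain_add)
```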

I understand this is because torch.fx's symbolic tracer cannot trace code with data-dependent control flow. The specific failure point is the kernel launch for add_kernel; I'd appreciate any advice on how to do this properly and on what I'm missing.
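One commonly suggested way around this, sketched here under assumptions: register the Triton wrapper as an FX leaf with torch.fx.wrap, and give replace_pattern a small replacement function that calls it. The tracer then records a single call_function node for the wrapper instead of stepping into its Python body. The sketch uses a hypothetical CPU stand-in, triton_add_stub, in place of triton_add, so it has not been checked against a real Triton launch. Whether the pattern also matches addition + 1 (binding y to the literal 1) may depend on the PyTorch version, so the stub accepts a scalar second argument.

```python
import torch
import torch.fx
from torch.fx import replace_pattern, symbolic_trace


def triton_add_stub(x, y):
    # Hypothetical CPU stand-in for triton_add. The branch on x.numel() is
    # Python control flow that symbolic tracing cannot step through.
    if x.numel() == 0:
        return x.clone()
    return torch.add(x, y)  # y may be a tensor or a Python scalar


# Register the function as an FX leaf: while tracing, calls to it become a
# single call_function node. torch.fx.wrap must run at module top level.
torch.fx.wrap("triton_add_stub")


class VectorAdd(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        addition = x + y
        return addition + 1


def pattern(x, y):
    return x + y


def replacement(x, y):
    return triton_add_stub(x, y)


gm = symbolic_trace(VectorAdd())
replace_pattern(gm, pattern, replacement)
print(gm.code)

x, y = torch.randn(8), torch.randn(8)
print(torch.allclose(gm(x, y), x + y + 1))
```

With the real kernel, the torch.fx.wrap call would name triton_add instead, placed at module top level in the module where the replacement function is defined.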

Many thanks.

cm2435 commented 11 months ago

For reference, here are my system details:

OS: Ubuntu 22.04 | NVCC: 12.3 | NVIDIA-SMI 525.85.05, Driver Version: 525.85.05 | GPU: NVIDIA RTX 3090

A tutorial on this would be great if anyone has one. I tried to follow @pommedeterresautee's implementation in the kernl package: wrap the Triton kernel launch in a torch.autograd.Function subclass, then wrap its .apply in a plain function. The tracer still complains that the kernel launch is not symbolically traceable.
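Since wrapping alone did not help here, another route that avoids tracing the replacement at all is to edit the traced graph directly and retarget the matching call_function nodes, in the style of the direct graph manipulation example in the FX docs. A minimal sketch, assuming a recent PyTorch: AddFn and add_fn are hypothetical names, and torch.add stands in for the real add_kernel launch so it runs on CPU.

```python
import operator

import torch
from torch.fx import Node, symbolic_trace


class AddFn(torch.autograd.Function):
    """Hypothetical autograd wrapper. A real version would launch add_kernel
    in forward; torch.add stands in here so the sketch runs on CPU."""

    @staticmethod
    def forward(ctx, x, y):
        return torch.add(x, y)

    @staticmethod
    def backward(ctx, grad_out):
        # d(x + y)/dx = d(x + y)/dy = 1, so the gradient passes straight through
        return grad_out, grad_out


def add_fn(x, y):
    return AddFn.apply(x, y)


class VectorAdd(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        addition = x + y
        return addition + 1


gm = symbolic_trace(VectorAdd())

# Retarget only tensor + tensor additions (both args are graph nodes), so
# `addition + 1` keeps operator.add. add_fn itself is never traced.
retargeted = 0
for node in gm.graph.nodes:
    if (node.op == "call_function" and node.target is operator.add
            and all(isinstance(arg, Node) for arg in node.args)):
        node.target = add_fn
        retargeted += 1

gm.graph.lint()
gm.recompile()
print(gm.code)
```

The graph keeps its shape and only the callee changes, so the Python body of the kernel wrapper only runs when the GraphModule is called with real tensors.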