triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/

Can I update the data in SRAM by index? #2643

Open fangwei123456 opened 10 months ago

fangwei123456 commented 10 months ago

Hi, I want to implement a conv1d in Triton. Since some of the input data is re-used, e.g., y[0] = conv(w, x[0:2]) and y[1] = conv(w, x[1:3]), I tried to update the data in SRAM by index:

x[0: k - 1] = x[1: ]
x[k - 1] = tl.load(x_seq_ptr + row_ptrs)
h = tl.dot(weight, x)
tl.store(h_seq_ptr + row_ptrs, h)

But then I get an AssertionError:

File "/home/wfang/anaconda3/envs/pytorch-env/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 733, in visit_Subscript
    assert node.ctx.__class__.__name__ == "Load"
AssertionError

Here is the complete code:

import torch

import triton
import triton.language as tl

@triton.jit
def sliding_psn_forward_kernel(
        x_seq_ptr, weight_ptr, h_seq_ptr,
        T, N, k,
        n_stride,
        BLOCK_SIZE_T: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr
):
    pid = tl.program_id(0)
    num_pid_t = tl.cdiv(T, BLOCK_SIZE_T)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_t = pid // num_pid_n
    pid_n = pid % num_pid_n

    weight = tl.load(weight_ptr)
    x = tl.zeros(shape=[weight.shape[0], BLOCK_SIZE_N], dtype=weight.dtype)

    for t in range(BLOCK_SIZE_T):
        new_row_start_ptr = pid_n + (pid_t + t) * n_stride
        row_ptrs = new_row_start_ptr + tl.arange(0, BLOCK_SIZE_N)

        x[0: k - 1] = x[1: ]
        x[k - 1] = tl.load(x_seq_ptr + row_ptrs)
        h = tl.dot(weight, x)
        tl.store(h_seq_ptr + row_ptrs, h)

def sliding_psn_forward(x_seq: torch.Tensor, weight: torch.Tensor):
    T, N = x_seq.shape
    k = weight.shape[0]

    grid = lambda META: (
        triton.cdiv(T, META['BLOCK_SIZE_T']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    h_seq = torch.empty_like(x_seq)

    sliding_psn_forward_kernel[grid](
        x_seq, weight, h_seq,
        T, N, k,
        x_seq.stride(1),
        BLOCK_SIZE_T=4,
        BLOCK_SIZE_N=4,
    )
    return h_seq

device = 'cuda:0'
T = 8
N = 16
k = 2
x_seq = torch.rand([T, N], device=device)
weight = torch.rand([k], device=device)

h_seq = sliding_psn_forward(x_seq, weight)

print(h_seq)
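
For reference, here is a plain-PyTorch version of the computation I am trying to express (a sketch assuming a causal window that is zero-padded for the first k - 1 steps):

import torch

def sliding_psn_forward_reference(x_seq: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # x_seq: [T, N], weight: [k]
    # h[t] = sum_{i=0}^{k-1} weight[i] * x[t - (k - 1) + i], with zeros before t = 0
    T, N = x_seq.shape
    k = weight.shape[0]
    pad = torch.zeros([k - 1, N], device=x_seq.device, dtype=x_seq.dtype)
    x_padded = torch.cat([pad, x_seq], dim=0)
    h_seq = torch.zeros_like(x_seq)
    for t in range(T):
        window = x_padded[t: t + k]                      # [k, N] window ending at step t
        h_seq[t] = (weight[:, None] * window).sum(dim=0)
    return h_seq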
yiakwy-xpu-ml-framework-team commented 6 months ago

I guess you cannot, because the tensor object does not allow in-place assignment.

That said, in my experience in-place assignment is allowed on the Graphcore IPU, so I think this is just a software constraint.
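
Triton tensors are immutable values, so subscript assignment is rejected by the frontend (that is what the visit_Subscript assertion is checking). One way around this is not to mutate the register tile at all and instead rebuild the k-sample window from global memory at every time step. A minimal, untested sketch (assuming k is passed as a tl.constexpr and n_stride is the row stride of x_seq, i.e. x_seq.stride(0) for a contiguous [T, N] tensor):

import triton
import triton.language as tl

@triton.jit
def sliding_psn_forward_kernel_v2(
        x_seq_ptr, weight_ptr, h_seq_ptr,
        T, N,
        n_stride,
        k: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
):
    # Each program handles one block of columns over all T time steps.
    pid_n = tl.program_id(0)
    col_offsets = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    col_mask = col_offsets < N

    for t in range(T):
        h = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
        # Reload the causal window instead of shifting a buffer in place.
        for i in range(k):
            t_src = t - (k - 1) + i
            w = tl.load(weight_ptr + i)
            x = tl.load(x_seq_ptr + t_src * n_stride + col_offsets,
                        mask=col_mask & (t_src >= 0), other=0.0)
            h += w * x
        tl.store(h_seq_ptr + t * n_stride + col_offsets, h, mask=col_mask)

It would be launched with a 1-D grid of triton.cdiv(N, BLOCK_SIZE_N) programs, e.g. sliding_psn_forward_kernel_v2[(triton.cdiv(N, 128),)](x_seq, weight, h_seq, T, N, x_seq.stride(0), k=k, BLOCK_SIZE_N=128). The extra loads per step will usually hit in cache, and nothing ever needs to be assigned by index.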