triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Flip or reverse associative scan? #2930

Open srush opened 10 months ago

srush commented 10 months ago

I am trying to figure out whether there is any clever way to flip a tensor along an axis or to run a right-to-left associative scan (such as a reverse cumulative sum). I know I can load a tensor in reverse order, but ideally I would like to do this without reloading all the tensors in reverse.
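For reference, here are the semantics of the requested right-to-left scan, written in plain Python (not Triton). This is a minimal sketch. The `scan` function and its `reverse` flag are hypothetical names, chosen to mirror the `reverse=True` option proposed in this thread. The reverse case is implemented as flip, scan, flip, which is what loading the tensor in reverse order amounts to.

```python
from itertools import accumulate


def scan(xs, combine, reverse=False):
    """Inclusive scan of xs under an associative combine(a, b).

    reverse=False: out[i] = xs[0] (+) xs[1] (+) ... (+) xs[i]
    reverse=True:  out[i] = xs[i] (+) xs[i+1] (+) ... (+) xs[-1]
    """
    if not reverse:
        return list(accumulate(xs, combine))
    # Right-to-left: walk the input backwards, then flip the result back.
    # The operands are swapped so that a non-commutative combine still sees
    # its arguments in their original left-to-right order.
    backwards = accumulate(reversed(xs), lambda acc, x: combine(x, acc))
    return list(backwards)[::-1]


print(scan([1, 2, 3, 4], lambda a, b: a + b, reverse=True))  # [10, 9, 7, 4]
print(scan(["a", "b", "c"], lambda a, b: a + b, reverse=True))  # ['abc', 'bc', 'c']
```

Swapping the operands matters only for non-commutative combine functions, as the string example shows. For a plain cumulative sum it makes no difference.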

Jokeren commented 10 months ago

It's not possible for now. May I ask what your use case is?

srush commented 10 months ago

https://github.com/srush/annotated-mamba/issues/1

Jokeren commented 10 months ago

It wouldn't be difficult to add a reverse option and modify our backend to support reverse=True. Unfortunately, I'm busy with other work. Not sure if @ThomasRaoux is interested?

Jokeren commented 10 months ago

@srush, or do you have a student interested in adding the reverse=True option? It could be an interesting project, since it isn't on the critical path and the student would get a chance to learn more about Triton's backend.

srush commented 10 months ago

Let me see if I can find a workaround; otherwise, sure, that would be fun.

proger commented 10 months ago

Is reading the memory in reverse inside ScanOpToLLVM a good idea? I am taking this route in my CUDA implementation and the algorithm should be close to the one in ScanOpToLLVM. I'd be down to write some MLIR C++ for this.
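The toy Hillis-Steele inclusive scan below illustrates what reading in reverse inside the scan itself could mean. It is written in plain Python, with one list slot standing in for each lane. It is a generic textbook algorithm, not Triton's ScanOpToLLVM lowering, and `hillis_steele_scan` is a hypothetical name. In the forward direction, each lane combines with the lane `d` positions behind it. Reversing only changes that to `d` positions ahead, so no data has to be reordered.

```python
import operator


def hillis_steele_scan(xs, combine, reverse=False):
    """Inclusive Hillis-Steele scan; each list slot simulates one lane."""
    n = len(xs)
    y = list(xs)
    d = 1
    while d < n:
        prev = y[:]  # every lane reads the values from the previous step
        for i in range(n):
            if not reverse and i - d >= 0:
                # Forward: combine with the lane d positions behind.
                y[i] = combine(prev[i - d], prev[i])
            elif reverse and i + d < n:
                # Reverse: combine with the lane d positions ahead instead.
                y[i] = combine(prev[i], prev[i + d])
        d *= 2
    return y


print(hillis_steele_scan([1, 2, 3, 4, 5], operator.add))  # [1, 3, 6, 10, 15]
print(hillis_steele_scan([1, 2, 3, 4, 5], operator.add, reverse=True))  # [15, 14, 12, 9, 5]
print(hillis_steele_scan(list("abcd"), operator.add, reverse=True))  # ['abcd', 'bcd', 'cd', 'd']
```

In both directions, the left operand of `combine` is always the element that comes earlier in the original order. This keeps the result correct for operators that are associative but not commutative.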

srush commented 10 months ago

Here's an in-memory reverse hack that seems to work for me. Unfortunately, with tl.dot I get a segfault (any recommendations for debugging those?).

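```python
import triton
import triton.language as tl

# Launch with, e.g., reverse[(1,)](x, z, L=32); L must be a power of 2.
@triton.jit
def reverse(x_ptr, z_ptr, L: tl.constexpr):
    Ls = tl.arange(0, L)
    # Load an L x L row-major tile and add a middle axis: shape (L, 1, L).
    x = tl.load(x_ptr + L * Ls[:, None] + Ls)[:, None, :]
    # Anti-diagonal mask: y[i, j] is True iff i == L - 1 - j.
    y = (Ls[:, None] == L - Ls - 1)
    # Broadcast to (L, L, L) and reduce over j: z[b, i] = x[b, L - 1 - i].
    z = tl.sum(x * y, 2)
    tl.store(z_ptr + L * Ls[:, None] + Ls, z)
```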


Jokeren commented 10 months ago

Thanks for the solution. It seems to use a similar idea to Triton's sort.

tl.dot doesn't support 3D matrix multiplications. And since this is a hack, either the tl.dot or the tl.sum version will be slower than a native reverse=True implementation anyway.

srush commented 10 months ago

Just to be clear, the tl.dot version is a 2D matmul (the code that segfaults is below).

But yes, I agree that reverse=True is the best way. Unfortunately, I am now running into a bug where tl.associative_scan gives the wrong answer on the forward pass, so I am trying to debug that 😢

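```python
import triton
import triton.language as tl

@triton.jit
def reverse(x_ptr, z_ptr, L: tl.constexpr):
    Ls = tl.arange(0, L)
    x = tl.load(x_ptr + L * Ls[:, None] + Ls)
    # Anti-diagonal permutation matrix, cast to float32 so it can feed tl.dot.
    y = (Ls[:, None] == L - Ls - 1).to(tl.float32)
    # Plain 2D matmul: x @ y reverses each row of x.
    z = tl.dot(x, y)
    tl.store(z_ptr + L * Ls[:, None] + Ls, z)
```

Jokeren commented 10 months ago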

> Unfortunately I am now running into a bug with tl.associative_scan giving the wrong answer on the forward pass, so I am trying to debug that 😢

Please let us know if the bug turns out to be in Triton.

confucianism72 commented 10 months ago

Although I don't fully understand how this hack works, I tested it and found that it works well. It seems that in @srush's code, L must be a power of 2 (tl.arange only accepts ranges whose length is a power of 2).
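If `L` is not a power of 2, one common workaround is to round the block up to the next power of 2 and mask the extra lanes. This is an assumption here, not something tested in this thread. The pure-Python simulation below shows the index arithmetic only; `next_power_of_2` and `reverse_padded` are hypothetical helpers, not Triton code. For a reverse, the key detail is that the anti-diagonal must be taken relative to the true length `L`, not the padded block size.

```python
def next_power_of_2(n):
    """Smallest power of 2 that is >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def reverse_padded(row):
    """Simulate reversing a length-L row inside a power-of-2 block.

    BLOCK plays the role of the tl.arange size; lanes >= L are masked.
    """
    L = len(row)
    BLOCK = next_power_of_2(L)
    # Masked load: out-of-range lanes read 0.0.
    padded = [row[j] if j < L else 0.0 for j in range(BLOCK)]
    # Anti-diagonal relative to the true length L (not BLOCK), so that
    # out[i] = row[L - 1 - i] for i < L and padded lanes contribute nothing.
    out = [sum(padded[j] * (j == L - 1 - i) for j in range(BLOCK)) for i in range(BLOCK)]
    # Masked store: keep only the first L lanes.
    return out[:L]


print(next_power_of_2(5))  # 8
print(reverse_padded([1.0, 2.0, 3.0, 4.0, 5.0]))  # [5.0, 4.0, 3.0, 2.0, 1.0]
```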

srush commented 10 months ago

Sorry, I shouldn't have called it a hack. Here's an explanation of what it is doing.

[Image: explanation of the reverse trick]
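The trick treats reversal as multiplication by an anti-diagonal (exchange) matrix. With `J[j][i] = 1` when `j == L - 1 - i`, the product `x @ J` reverses every row of `x`. The plain-Python sketch below uses hypothetical helper names and is not Triton code. It computes the reversal both ways used in this thread: as a matmul, like the `tl.dot` version, and as an explicit broadcast-then-reduce, like the `tl.sum` version, which materializes the B x L x L intermediate.

```python
def antidiagonal(L):
    """J[j][i] = 1.0 iff j == L - 1 - i (the mask y / matrix M in the kernels)."""
    return [[1.0 if j == L - 1 - i else 0.0 for i in range(L)] for j in range(L)]


def reverse_rows_dot(x):
    """tl.dot-style version: z = x @ J, so z[b][i] = x[b][L - 1 - i]."""
    L = len(x[0])
    J = antidiagonal(L)
    return [[sum(row[j] * J[j][i] for j in range(L)) for i in range(L)] for row in x]


def reverse_rows_sum(x):
    """tl.sum-style version: materialize the B x L x L product, then reduce."""
    B, L = len(x), len(x[0])
    J = antidiagonal(L)
    # prod[b][i][j] = x[b][j] * J[j][i]; this is the B x L x L intermediate.
    prod = [[[x[b][j] * J[j][i] for j in range(L)] for i in range(L)] for b in range(B)]
    return [[sum(prod[b][i]) for i in range(L)] for b in range(B)]


x = [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]
print(reverse_rows_dot(x))  # [[3.0, 2.0, 1.0, 0.0], [7.0, 6.0, 5.0, 4.0]]
print(reverse_rows_sum(x))  # same result
```

Both forms do L multiply-adds per output element. The broadcast form also holds all B x L x L products at once.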

However, as @Jokeren notes, my method requires creating a B x L x L intermediate. This wouldn't be a problem if tl.dot worked, but it seems pretty broken: the following Triton version segfaults for me.

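```python
import torch
import triton
import triton.language as tl

L = 32  # tile size; tl.arange requires a power of 2

@triton.jit
def reverse(x_ptr, z_ptr, L: tl.constexpr):
    Ls = tl.arange(0, L)
    # Load an L x L row-major tile.
    x = tl.load(x_ptr + L * Ls[:, None] + Ls)
    # Anti-diagonal permutation matrix: M[j, i] = 1 iff j == L - 1 - i.
    M = (Ls[:, None] == L - Ls - 1).to(tl.float32)
    # z[b, i] = sum_j x[b, j] * M[j, i] = x[b, L - 1 - i], i.e. each row reversed.
    z = tl.dot(x, M)
    tl.store(z_ptr + L * Ls[:, None] + Ls, z)

x = (torch.arange(L) + torch.zeros(L, L)).float().cuda()  # every row is 0..L-1
z = torch.empty_like(x)
reverse[(1,)](x, z, L=L)
torch.testing.assert_close(z, x.flip(1))  # each row should come back reversed
```

confucianism72 commented 10 months ago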

Well, the code above works for me without reporting any errors. Maybe check your Triton version?

srush commented 10 months ago

What's your version? I was using nightly.

confucianism72 commented 10 months ago

I am using Triton 2.1.0 and Python 3.11.5. I copied your code directly and ran it in a Jupyter notebook.

confucianism72 commented 10 months ago
[Image attached]