srush opened this issue 10 months ago

I am trying to figure out if there is any clever way to flip a tensor along an axis or run a right-to-left associative sum? I know I can load a tensor in the reverse direction, but ideally I would like to be able to do this without reloading all the tensors in reverse order.
It's not possible for now. May I ask what your use case is?
It's not difficult to add a reverse option and modify our backend to support reverse=True. Unfortunately I'm occupied with other stuff, not sure if @ThomasRaoux is interested?

@srush, or do you have any student interested in adding the reverse=True option? It might be an interesting option since the support is not on the critical path and the student would get a chance to learn more about triton's backend.
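For concreteness, the caller-side API could look roughly like the sketch below. This is illustrative only: the reverse=True keyword does not exist on tl.associative_scan yet, and the kernel and names here are hypothetical.

import triton
import triton.language as tl

@triton.jit
def combine(a, b):
    return a + b

@triton.jit
def rev_cumsum(x_ptr, z_ptr, BLOCK: tl.constexpr):
    # Hypothetical right-to-left cumulative sum of one block, assuming the
    # proposed reverse=True option lands on tl.associative_scan.
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    z = tl.associative_scan(x, 0, combine, reverse=True)  # proposed option
    tl.store(z_ptr + offs, z)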
Let me see if I can find a workaround, otherwise, sure, that would be fun.
Is reading the memory in reverse inside ScanOpToLLVM a good idea? I am taking this route in my CUDA implementation and the algorithm should be close to the one in ScanOpToLLVM. I'd be down to write some MLIR C++ for this.
Here's an in-memory reverse hack that seems to work for me. Unfortunately with tl.dot I get a segfault. (Any recs for debugging those?)
import triton
import triton.language as tl

L = 32

@triton.jit
def reverse(x_ptr, z_ptr):
    Ls = tl.arange(0, L)
    # Load the L x L tile and add a middle axis: shape (L, 1, L).
    x = tl.load(x_ptr + L * Ls[:, None] + Ls)[:, None, :]
    # Anti-diagonal mask: y[j, k] is 1 where j == L - k - 1.
    y = (Ls[:, None] == L - Ls - 1)
    # Summing over the last axis picks x[i, L - j - 1], i.e. a reverse along axis 1.
    z = tl.sum(x * y, 2)
    tl.store(z_ptr + L * Ls[:, None] + Ls, z)
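A quick way to sanity-check this is to compare against torch.flip; the driver below is just illustrative (random input, single-program launch on one L x L tile).

import torch

x = torch.randn(L, L, device="cuda")
z = torch.empty_like(x)
reverse[(1,)](x, z)
# Each row of z should be the corresponding row of x reversed.
assert torch.allclose(z, x.flip(-1))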
Thanks for the solution. Seems like it uses a similar idea to triton's sort. tl.dot doesn't support 3D matrix multiplications. Since this is a hack, either tl.dot or tl.sum will be slower than the native reverse=True version anyway.
Just to be clear, the tl.dot version is a 2D matmul (code below, which segfaults). But yes, I agree that reverse=True is the best way. Unfortunately I am now running into a bug with tl.associative_scan giving the wrong answer on the forward pass, so I am trying to debug that :cry:
L = 32

@triton.jit
def reverse(x_ptr, z_ptr):
    Ls = tl.arange(0, L)
    x = tl.load(x_ptr + L * Ls[:, None] + Ls)
    y = (Ls[:, None] == L - Ls - 1).to(tl.float32)  # forget the exact syntax
    z = tl.dot(x, y)
    tl.store(z_ptr + L * Ls[:, None] + Ls, z)
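The same idea in a couple of lines of plain PyTorch, just as an illustration (the shapes here are arbitrary):

import torch

L = 32
x = torch.randn(4, L)
# The mask is the exchange matrix: the identity flipped, ones on the anti-diagonal.
M = torch.eye(L).flip(0)
# (x @ M)[:, j] == x[:, L - 1 - j], i.e. x reversed along its last axis.
assert torch.allclose(x @ M, x.flip(-1))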
> Unfortunately I am now running into a bug with tl.associative_scan giving the wrong answer on the forward pass, so I am trying to debug that 😢

Please let us know if the bug is caused by triton.
Although I don't fully understand how this hack works, I tested it and found it works well. It seems that in @srush's code, L must be a power of 2.
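That lines up with tl.arange, which requires the range size to be a power of 2. A possible workaround, sketched with illustrative names (reverse_row, N, and BLOCK are not from this thread), is to round the block size up with triton.next_power_of_2 and mask out-of-range elements:

import triton
import triton.language as tl
import torch

@triton.jit
def reverse_row(x_ptr, z_ptr, N, BLOCK: tl.constexpr):
    # BLOCK is a power of 2 >= N; lanes past N are masked off.
    offs = tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    # Write element j of the row to position N - 1 - j.
    tl.store(z_ptr + (N - 1 - offs), x, mask=mask)

N = 40                             # not a power of 2
BLOCK = triton.next_power_of_2(N)  # 64
x = torch.randn(N, device="cuda")
z = torch.empty_like(x)
reverse_row[(1,)](x, z, N, BLOCK=BLOCK)
assert torch.allclose(z, x.flip(0))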
Sorry, shouldn't have called it a hack. Here's an explanation of what it is doing: the mask is effectively the L x L exchange matrix (ones on the anti-diagonal), so summing or matmul-ing against it picks element L - 1 - j for output position j, which reverses the tensor along that axis.

However, as @Jokeren notes, my method requires creating a B x L x L intermediate. This wouldn't be a problem, but tl.dot seems pretty broken in that the following triton version segfaults for me.
import triton
import triton.language as tl
import torch

L = 32

@triton.jit
def reverse(x_ptr, z_ptr):
    Ls = tl.arange(0, L)
    x = tl.load(x_ptr + L * Ls[:, None] + Ls)
    # Anti-diagonal (exchange) matrix as float so it can be fed to tl.dot.
    M = (Ls[:, None] == L - Ls - 1).to(tl.float32)
    z = tl.dot(x, M)
    tl.store(z_ptr + L * Ls[:, None] + Ls, z)

x = (torch.arange(L) + torch.zeros(L, L)).float().cuda()
z = (torch.arange(L) + torch.zeros(L, L)).float().cuda()
reverse[(1,)](x, z)
Well, the above code works well for me. It did not report an error. Maybe you need to check your triton version?
What's your version? I was using nightly.
I am using triton 2.1.0, python 3.11.5. I directly copied your code and ran it in Jupyter notebook.