Open jiangzzsss opened 1 year ago
Yeah, on-chip indexing through shared memory isn't supported yet. It's on the roadmap, but it's a fairly advanced feature, so we don't have a specific timeline yet.
Thanks for the reply! I'm still curious, though: if we store the values back and use the indices as pointers to load them, will this be slow?
It's expected to be slow, since you store the values in global memory, though in some cases the accesses will be served from the cache.
Triton just raises an assertion error when trying to index a local tensor. I suppose it is related to this issue. Are there any workarounds?
Any updates on this? Is there still no way to do indexing in a Triton kernel?
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/_triton/k_index_select_cat.py
There’s this in xformers, which seems similar to indexing into a sparse tensor.
Yes, but it goes through global memory, which is slow, as mentioned by @Jokeren.
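For small tensors, one possible way to avoid the round trip through global memory is to replace the index lookup with a compare-and-reduce: build a one-hot mask from the indices and sum over it. The sketch below models that idea in NumPy on the host, so it is not a Triton kernel; the function name `gather_by_compare` and the sample data are made up for illustration. The assumption is that the same expression can be written in a kernel with `tl.arange`, `tl.where`, and `tl.sum`. Note that it does O(M*N) work for M indices into N values, so it only makes sense when N is small.

```python
import numpy as np

def gather_by_compare(x, idx):
    """Host-side model of an index-free gather: out[j] = x[idx[j]].

    Hypothetical helper; each step maps to an elementwise or reduction op
    (tl.arange, tl.where, tl.sum) rather than to tensor indexing.
    """
    n = x.shape[0]
    positions = np.arange(n)                     # candidate source positions 0..n-1
    onehot = idx[:, None] == positions[None, :]  # (M, N) mask, True where idx[j] == i
    picked = np.where(onehot, x[None, :], 0)     # keep x[i] only at the matching position
    return picked.sum(axis=1)                    # exactly one nonzero term per output row

x = np.array([10.0, 20.0, 30.0, 40.0])
idx = np.array([3, 0, 0, 2])
out = gather_by_compare(x, idx)
```

Whether this beats the global-memory path depends on the sizes involved; it trades memory traffic for extra arithmetic.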
I have a similar issue, but I only want to index different blocks, for example to compute a spline function up to a certain order:
data = tl.zeros((4, BLOCK_SIZE))
data[0] = w
data[1] = 1 - w
.....
I get a similar compiler error. It could easily be fixed by creating 4 different shared memory blocks (each with its own name), but in that case iterating over these blocks with a for loop becomes the issue.
I think I can unroll the loop and name everything to get around the problem, but that would produce unmaintainable code. Is there a known trick to get this to work other than going through global memory?
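One trick that might apply here is to write and read rows through a mask on a row-index tensor instead of subscripting. The sketch below models it in NumPy on the host rather than in a kernel. The helpers `set_row` and `get_row` are hypothetical names, and the contents of rows 2 and 3 are placeholders, since the original snippet elides them. The assumption is that the same expressions carry over to Triton using `tl.arange(0, 4)[:, None]`, `tl.where`, and `tl.sum(..., axis=0)`.

```python
import numpy as np

BLOCK_SIZE = 8
w = np.linspace(0.0, 1.0, BLOCK_SIZE)  # stand-in for the per-element weights

rows = np.arange(4)[:, None]  # (4, 1) row ids, broadcast against (4, BLOCK_SIZE)

def set_row(data, k, value):
    """Return data with row k replaced by value, without writing data[k] = ..."""
    return np.where(rows == k, value[None, :], data)

def get_row(data, k):
    """Return row k by masking out the other rows and reducing over axis 0."""
    return np.where(rows == k, data, 0.0).sum(axis=0)

data = np.zeros((4, BLOCK_SIZE))
data = set_row(data, 0, w)
data = set_row(data, 1, 1 - w)
# Placeholder rows: the original post does not say what rows 2 and 3 hold.
data = set_row(data, 2, w * w)
data = set_row(data, 3, (1 - w) ** 2)

# The rows can now be visited in an ordinary loop over k.
total = np.zeros(BLOCK_SIZE)
for k in range(4):
    total += get_row(data, k)
```

Each read costs a masked reduction over all four rows, so this keeps the code maintainable at the price of some redundant arithmetic.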
We'd like to do some indexing in Triton kernels. Say we have x_ptr, idx_ptr, and out_ptr.
We have two approaches:
1. It works.
2. It reports errors (the error message is at the end).
We want to know:
We are using Triton version 2.0.0.dev20221120 and Python 3.8.0, running on an A100.
Error logs of approach 2: