triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Index in triton #974

Open jiangzzsss opened 1 year ago

jiangzzsss commented 1 year ago

We'd like to do some indexing in Triton kernels. Say we have x_ptr, idx_ptr, and out_ptr:

x = tl.load(x_ptr + offsets, mask = mask)
idx = tl.load(idx_ptr + offsets, mask = mask)

We have tried two approaches. Approach 1:

idx = idx.to(tl.int32)
output = tl.load(x_ptr + idx)

This works. Approach 2:

output = tl.zeros([BLOCK_SIZE, ], dtype=tl.float32)
for i in range(0, BLOCK_SIZE):
    output[i] = x[idx[i]]

This reports errors (the error message is at the end). We want to know:

  1. If we use approach 1, is it memory efficient, since we use load?
  2. If we try x[0], it also errors with "TypeError: 'constexpr' object is not iterable". We didn't see much about this in the docs, so are there any other ways of doing indexing?
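For reference, approach 1 written out as a complete kernel looks like this (a minimal sketch; it assumes every value in idx is a valid offset into x_ptr and that the gathered values are stored back at offsets):

import triton
import triton.language as tl

@triton.jit
def gather_kernel(
    x_ptr,       # pointer to the data to gather from
    idx_ptr,     # pointer to the integer indices
    output_ptr,  # pointer to the output
    n_elements,  # number of indices to process
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load a block of indices, then use them directly as offsets into x_ptr.
    idx = tl.load(idx_ptr + offsets, mask=mask, other=0).to(tl.int32)
    out = tl.load(x_ptr + idx, mask=mask)
    tl.store(output_ptr + offsets, out, mask=mask)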

We are using Triton version 2.0.0.dev20221120 with Python 3.8.0 on an A100. Error log of approach 2:

Traceback (most recent call last):
  File "<string>", line 21, in tri_index_kernel
KeyError: ('2-.-0-.-0-1e8410f206c822547fb50e2ea86e45a6-2b0c5161c53c71b37ae20a9996ee4bb8-3aa563e00c5c695dd945e23b09a86848-42648570729a4835b21c1c18cebedbfe-ff946bd4b3b4a4cbdf8cedc6e1c658e0-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.float32, torch.float32, torch.float32, 'i32'), (64,), (True, True, True, (True, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 838, in make_triton_ir
    generator.visit(fn.parse())
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
    return super().visit(node)
  File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 360, in visit
    return visitor(node)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 260, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 368, in generic_visit
    self.visit(item)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
    return super().visit(node)
  File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 360, in visit
    return visitor(node)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 320, in visit_FunctionDef
    has_ret = self.visit_compound_statement(node.body)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 254, in visit_compound_statement
    self.last_ret = self.visit(stmt)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
    return super().visit(node)
  File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 360, in visit
    return visitor(node)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 648, in visit_For
    self.visit_compound_statement(node.body)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 254, in visit_compound_statement
    self.last_ret = self.visit(stmt)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
    return super().visit(node)
  File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 360, in visit
    return visitor(node)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 364, in visit_Assign
    _names += [self.visit(target)]
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
    return super().visit(node)
  File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 360, in visit
    return visitor(node)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 576, in visit_Subscript
    assert node.ctx.__class__.__name__ == "Load"
AssertionError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "ztest.py", line 51, in <module>
    output = tri_index(x, idx)
  File "ztest.py", line 44, in tri_index
    tri_index_kernel[grid](x, idx, output, n_elements, BLOCK_SIZE=64)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/runtime/jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "<string>", line 41, in tri_index_kernel
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 1256, in compile
    asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 892, in _compile
    module, _ = make_triton_ir(fn, signature, specialization, constants)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 843, in make_triton_ir
    raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 23:8:
def tri_index_kernel(
    x_ptr,  # *Pointer* to first input vector
    idx_ptr,  # *Pointer* to second input vector
    output_ptr,  # *Pointer* to output vector
    n_elements,  # Size of the vector
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
                 # NOTE: `constexpr` so it can be used as a shape value
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask = mask)
    idx = tl.load(idx_ptr + offsets, mask = mask)
    output = tl.zeros([BLOCK_SIZE, ], dtype=tl.float32)
    min_off = tl.min(offsets, axis=0)
    max_off = tl.max(offsets, axis=0)
    # idx //= 1
    idx = idx.to(tl.int32)
    output = tl.load(x_ptr + idx)
    for i in range(0, BLOCK_SIZE):
        output[i] = x[idx[i]]
        ^
ptillet commented 1 year ago

Yeah, on-chip indexing through shared memory isn't supported yet. It's on the roadmap, but it's a pretty advanced feature, so we haven't come up with a specific timeline yet.

jiangzzsss commented 1 year ago

Thanks for the reply! I am still curious, though: if we store the values back and use the indexes as pointers to load them, will this be slow?

Jokeren commented 1 year ago

Thanks for the reply! I am still curious, though: if we store the values back and use the indexes as pointers to load them, will this be slow?

It's expected to be slow, since you store values in global memory, though in some cases you will go through the cache.
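To make that concrete, the round trip looks roughly like this (a minimal sketch with a hypothetical scratch_ptr global buffer; it assumes each program only reads back locations written by its own threads, with tl.debug_barrier() making those stores visible; indices that cross program boundaries would need a separate kernel launch):

import triton
import triton.language as tl

@triton.jit
def store_then_gather_kernel(
    x_ptr, idx_ptr, scratch_ptr, output_ptr, n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Compute (here: just load) a block of values and spill it to global memory.
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(scratch_ptr + offsets, x, mask=mask)
    # Make the stores visible to all threads of this program before reading back.
    tl.debug_barrier()

    # Gather through global memory, using the indices as pointer offsets.
    idx = tl.load(idx_ptr + offsets, mask=mask, other=0).to(tl.int32)
    out = tl.load(scratch_ptr + idx, mask=mask)
    tl.store(output_ptr + offsets, out, mask=mask)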

nlgranger commented 4 months ago

Triton just raises an assertion error when trying to index a local tensor. I suppose it is related to this issue. Are there any workarounds?

marcelroed commented 3 months ago

Any updates on this? Is there still no way to do indexing in a Triton kernel?

jselvam11 commented 2 months ago

https://github.com/facebookresearch/xformers/blob/main/xformers/ops/_triton/k_index_select_cat.py

There's this kernel in xformers; it seems similar to indexing into a sparse tensor.
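It is not the xformers kernel itself, but the index_select-style row gather it implements looks roughly like this (a minimal sketch with hypothetical names, assuming a contiguous (num_rows, row_size) source and a 2D launch grid of (num_indices, cdiv(row_size, BLOCK_SIZE)) programs):

import triton
import triton.language as tl

@triton.jit
def index_select_rows_kernel(
    src_ptr,      # pointer to a contiguous (num_rows, row_size) tensor
    indices_ptr,  # pointer to the row indices to select
    out_ptr,      # pointer to the (num_indices, row_size) output
    row_size,     # number of elements per row
    BLOCK_SIZE: tl.constexpr,
):
    # One program per (selected row, block of columns).
    row_id = tl.program_id(axis=0)
    cols = tl.program_id(axis=1) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = cols < row_size

    # Look up which source row this output row comes from, then copy one block of it.
    src_row = tl.load(indices_ptr + row_id)
    vals = tl.load(src_ptr + src_row * row_size + cols, mask=mask)
    tl.store(out_ptr + row_id * row_size + cols, vals, mask=mask)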

nlgranger commented 2 months ago

https://github.com/facebookresearch/xformers/blob/main/xformers/ops/_triton/k_index_select_cat.py

There's this kernel in xformers; it seems similar to indexing into a sparse tensor.

Yes, but it goes through global memory, which is slow, as @Jokeren mentioned.

cagrikymk commented 1 month ago

I have a similar issue, but I only want to index different blocks, e.g. to compute a spline function up to a certain order:

data = tl.zeros((4, BLOCK_SIZE), dtype=tl.float32)
data[0] = w
data[1] = 1 - w
.....

I get a similar kind of compiler error. This could easily be fixed by creating 4 different blocks (each with its own name), but then iterating over these blocks with a for loop becomes the issue.

I think I can unroll and name everything to overcome the problem, but that would produce unmaintainable code. Is there a known trick to get this to work, other than going through global memory?
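The unrolled version I have in mind looks roughly like this (a minimal sketch: the order-2 and order-3 rows and the (4, n_elements) output layout are placeholders, not the real spline recurrence):

import triton
import triton.language as tl

@triton.jit
def spline_rows_kernel(w_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    w = tl.load(w_ptr + offsets, mask=mask)

    # Unrolled: every "row" of the conceptual (4, BLOCK_SIZE) tensor gets its own name.
    row0 = w
    row1 = 1 - w
    row2 = row1 * w   # placeholder for the real order-2 term
    row3 = row2 * w   # placeholder for the real order-3 term

    # The loop over rows also has to be unrolled by hand, e.g. storing them as
    # four consecutive rows of a contiguous (4, n_elements) output buffer.
    tl.store(out_ptr + 0 * n_elements + offsets, row0, mask=mask)
    tl.store(out_ptr + 1 * n_elements + offsets, row1, mask=mask)
    tl.store(out_ptr + 2 * n_elements + offsets, row2, mask=mask)
    tl.store(out_ptr + 3 * n_elements + offsets, row3, mask=mask)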