jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

support for concat and dynamic slice in pallas/triton #24292

Open juntang-zhuang opened 6 days ago

juntang-zhuang commented 6 days ago

Feature request: support for pad and slicing operations inside a Pallas kernel, e.g.:

x = jnp.pad(x, ((1, 0), (2, 0))); x = x[:, :-2]
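For reference, the requested operation can be checked in plain JAX on ordinary arrays. This minimal sketch runs outside `pallas_call` and says nothing about whether these ops lower to Triton. It shows that the pad-then-slice acts as a shift: one zero row is added on top, and the columns move right by 2 with zero fill. It also produces the same result with a single `jax.lax.pad`, whose `(low, high, interior)` config accepts negative edge padding. The function names are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

def pad_then_slice(x):
    # The operation from the request: 1 zero row on top, 2 zero columns
    # on the left, then drop the last 2 columns.
    x = jnp.pad(x, ((1, 0), (2, 0)))
    return x[:, :-2]

def single_lax_pad(x):
    # Same result in one op: lax.pad allows negative low/high padding, so
    # (low=2, high=-2, interior=0) on axis 1 adds 2 zero columns on the
    # left and removes 2 columns on the right.
    return lax.pad(x, jnp.zeros((), x.dtype), ((1, 0, 0), (2, -2, 0)))

x = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
y = pad_then_slice(x)
print(y.shape)  # (4, 4)
print(y)        # y[i, j] == x[i - 1, j - 2] where valid, else 0
```

Either function could serve as a correctness oracle when testing an in-kernel implementation.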

justinjfu commented 2 days ago

After some initial investigation, this seems non-trivial: pad/slice do not directly exist in the Triton API (https://triton-lang.org/main/python-api/triton.language.html), and the concat op for arrays only supports can_reorder=True (https://triton-lang.org/main/python-api/generated/triton.language.cat.html#triton.language.cat), i.e. the compiler is allowed to reorder elements while concatenating. Will update here later if I find a better solution.
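The order restriction matters because the requested pad can be rewritten as an ordered concatenation with zero blocks, followed by the same static slice. The plain-JAX sketch below checks that decomposition. It runs outside any kernel, and `pad_then_slice_via_concat` and `zeros_on_wrong_side` are hypothetical helpers. It also shows that placing the zeros on the other side changes the result. A concat that is free to reorder elements would therefore not, on its own, be a drop-in replacement for pad.

```python
import jax.numpy as jnp

def pad_then_slice_via_concat(x):
    # Express jnp.pad(x, ((1, 0), (2, 0))) as ordered concatenations
    # with zero blocks, then apply the same static slice.
    m, n = x.shape
    x = jnp.concatenate([jnp.zeros((m, 2), x.dtype), x], axis=1)      # 2 zero columns on the left
    x = jnp.concatenate([jnp.zeros((1, n + 2), x.dtype), x], axis=0)  # 1 zero row on top
    return x[:, :-2]

def zeros_on_wrong_side(x):
    # Same blocks, but the zero columns are placed after x instead of before.
    m, n = x.shape
    x = jnp.concatenate([x, jnp.zeros((m, 2), x.dtype)], axis=1)
    x = jnp.concatenate([jnp.zeros((1, n + 2), x.dtype), x], axis=0)
    return x[:, :-2]

x = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
ref = jnp.pad(x, ((1, 0), (2, 0)))[:, :-2]
assert jnp.array_equal(pad_then_slice_via_concat(x), ref)
assert not jnp.array_equal(zeros_on_wrong_side(x), ref)  # order matters
```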