juntang-zhuang opened 6 days ago
After some initial investigation, this seems non-trivial, since pad and slice ops do not exist directly in the Triton API (https://triton-lang.org/main/python-api/triton.language.html), and the concat op for arrays (`tl.cat`) currently only supports `can_reorder=True` (https://triton-lang.org/main/python-api/generated/triton.language.cat.html#triton.language.cat). Will update here later if I find a better solution.
Feature request: support pad and slicing operations inside a Pallas kernel, e.g.:
`x = jnp.pad(x, ((1, 0), (2, 0))); x = x[:, :-2]`
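For reference, the pad-then-slice above pads 1 row on top and 2 columns on the left, then drops the last 2 columns. The result has shape `(H + 1, W)` and `out[r, c] = x[r - 1, c - 2]` where that index is valid, else 0. The sketch below runs in plain JAX outside any kernel. It shows that the same result can be written as a shifted read plus a mask, with no pad or slice. That is roughly the form a masked load takes in Triton (`tl.load` accepts a `mask`), but whether Pallas can lower it this way for the Triton backend is an assumption this sketch does not test. The helpers `pad_then_slice` and `shifted_read` are hypothetical names, not Pallas or Triton APIs.

```python
import jax.numpy as jnp


def pad_then_slice(x):
    # Reference semantics from the request: pad 1 row on top and 2 columns
    # on the left, then drop the last 2 columns. Output shape is (H + 1, W).
    x = jnp.pad(x, ((1, 0), (2, 0)))
    return x[:, :-2]


def shifted_read(x, row_shift=1, col_shift=2):
    # Hypothetical alternative: out[r, c] = x[r - row_shift, c - col_shift]
    # when that source index is in bounds, otherwise 0.
    h, w = x.shape
    rows = jnp.arange(h + row_shift)[:, None]  # output gains row_shift rows
    cols = jnp.arange(w)[None, :]              # output keeps w columns
    src_r = rows - row_shift
    src_c = cols - col_shift
    mask = (src_r >= 0) & (src_r < h) & (src_c >= 0) & (src_c < w)
    # Clip so the gather stays in bounds; masked positions are zeroed below.
    vals = x[jnp.clip(src_r, 0, h - 1), jnp.clip(src_c, 0, w - 1)]
    return jnp.where(mask, vals, jnp.zeros_like(vals))
```

Usage: for `x = jnp.arange(12).reshape(3, 4)`, both functions return a `(4, 4)` array. The first row is all zeros and the second row is `[0, 0, 0, 1]`.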