jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[pallas] Request to add support for select/argmax primitive in Pallas TPU #22508

Open Lime-Cakes opened 4 months ago

Lime-Cakes commented 4 months ago

Currently, it's not possible to use jnp.select in a Pallas kernel or in a BlockSpec. Trying to do so causes the following error:

jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Unimplemented primitive in Pallas TPU lowering: argmax. Please file an issue on https://github.com/google/jax/issues.

Supporting select would allow cleaner conditional control within a BlockSpec. Some use cases need a conditional branch, especially when selecting a block index.
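A minimal sketch of the kind of index selection described above, written as a plain-JAX function rather than inside an actual Pallas kernel. The name `index_map_select` and the grid-to-block mapping are hypothetical, chosen only for illustration. As far as I can tell, recent JAX versions implement `jnp.select` by taking an `argmax` over the stacked conditions, which would explain why the error message names `argmax` even though the code only calls `jnp.select`. The function runs fine under `jax.jit`; it is the Pallas TPU lowering that rejects it.

```python
import jax
import jax.numpy as jnp


# Hypothetical index_map body: maps grid index i to a block index
# (block 0 for i < 2, block 1 for 2 <= i < 4, block 2 otherwise).
# A real Pallas index_map returns one index per block dimension.
def index_map_select(i):
    return jnp.select([i < 2, i < 4], [0, 1], default=2)


# Works under jit in plain JAX. Inside a Pallas TPU BlockSpec this is the
# pattern that fails, since jnp.select (in recent JAX versions) computes
# its branch index with argmax, which has no Pallas TPU lowering.
fn = jax.jit(index_map_select)
print([int(fn(jnp.int32(i))) for i in range(6)])
```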

Lime-Cakes commented 4 months ago

I'd like to add that using cond is a usable workaround, but the code is longer and harder to read.
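For comparison, a sketch of the cond workaround mentioned above, using the same hypothetical three-way mapping (`index_map_cond` and `index_map_where` are illustrative names, not part of any API). It also shows a nested `jnp.where` variant, which traces to `lax.select_n` rather than `argmax`. Whether Pallas TPU accepts either form inside an index_map is an assumption that is not verified here; the block only checks that all variants agree in plain JAX.

```python
import jax.numpy as jnp
from jax import lax


# Hypothetical index_map body using nested lax.cond (the workaround):
# block 0 for i < 2, block 1 for 2 <= i < 4, block 2 otherwise.
def index_map_cond(i):
    return lax.cond(
        i < 2,
        lambda j: jnp.int32(0),
        lambda j: lax.cond(
            j < 4, lambda _: jnp.int32(1), lambda _: jnp.int32(2), j
        ),
        i,
    )


# Same mapping with nested jnp.where, which traces to lax.select_n
# instead of argmax. Assumption: not verified against Pallas TPU lowering.
def index_map_where(i):
    return jnp.where(i < 2, 0, jnp.where(i < 4, 1, 2))


for i in range(6):
    print(i, int(index_map_cond(jnp.int32(i))), int(index_map_where(jnp.int32(i))))
```

Each extra branch adds another level of nesting with cond, which is the verbosity the comment above refers to; a supported jnp.select would express the same mapping in one line.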