Currently, it's not possible to use jnp.select in a Pallas kernel or in a BlockSpec. Trying to do so causes the following error:
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Unimplemented primitive in Pallas TPU lowering: argmax. Please file an issue on https://github.com/google/jax/issues.
Supporting jnp.select would allow better conditional control within a BlockSpec. There are use cases for conditional branching there, especially when selecting an index.
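To make the request concrete, here is a minimal sketch of the first-match-wins behaviour jnp.select provides, rebuilt from nested jnp.where calls. The helper name `select_via_where` is hypothetical and not part of JAX. Since the error above names argmax, a formulation that only uses jnp.where might sidestep it, but that is an assumption: this snippet has only been reasoned about as plain JAX, not inside a Pallas TPU kernel or a BlockSpec.

```python
import jax.numpy as jnp


def select_via_where(condlist, choicelist, default=0):
    """Hypothetical helper: first-true-wins selection using nested jnp.where."""
    out = jnp.asarray(default)
    # Walk backwards so that earlier conditions take priority,
    # matching the semantics of jnp.select.
    for cond, choice in zip(reversed(condlist), reversed(choicelist)):
        out = jnp.where(cond, choice, out)
    return out


i = jnp.arange(6)
conds = [i < 2, i < 4]
choices = [i * 10, i * 100]

result = select_via_where(conds, choices, default=-1)
expected = jnp.select(conds, choices, default=-1)
```

Outside Pallas, `result` and `expected` hold the same values, so the two forms are interchangeable for this kind of index selection. Native jnp.select support would still be preferable to hand-writing the chain.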