jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.59k stars, 2.82k forks

[sharding_in_types] Make argmax and argmin work with sharding_in_types. This also requires adding `reduce_p` sharding rule #25064

Status: Closed. copybara-service[bot] closed this pull request 3 days ago.

copybara-service[bot] commented 3 days ago

[sharding_in_types] Make argmax and argmin work with sharding_in_types. This also requires adding a `reduce_p` sharding rule.
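The PR text does not show the sharding rule itself. Here is a minimal sketch of what such a rule could compute. It assumes a PartitionSpec-style rule in which the partition entry of each reduced dimension is dropped and every other entry carries over unchanged. argmax and argmin reduce a single axis, so under that assumption they can reuse the same rule. `reduce_out_spec` and `argminmax_out_spec` are hypothetical helpers, not JAX APIs. Plain tuples stand in for `jax.sharding.PartitionSpec` so the example runs without JAX.

```python
from typing import Sequence, Tuple, Union

# Hypothetical stand-in for a PartitionSpec entry: None (replicated),
# a single mesh axis name, or a tuple of mesh axis names.
SpecEntry = Union[None, str, Tuple[str, ...]]


def reduce_out_spec(in_spec: Sequence[SpecEntry],
                    axes: Sequence[int]) -> Tuple[SpecEntry, ...]:
    """Hypothetical sharding rule for a reduction over `axes`.

    Each reduced dimension disappears from the output shape, so its
    partition entry is dropped; the other entries are kept in order.
    """
    ndim = len(in_spec)
    for a in axes:
        if not -ndim <= a < ndim:
            raise ValueError(f"axis {a} out of range for rank {ndim}")
    reduced = {a % ndim for a in axes}  # normalize negative axes
    return tuple(s for i, s in enumerate(in_spec) if i not in reduced)


def argminmax_out_spec(in_spec: Sequence[SpecEntry],
                       axis: int) -> Tuple[SpecEntry, ...]:
    # argmax/argmin reduce exactly one axis, so they reuse the reduce rule.
    return reduce_out_spec(in_spec, [axis])


if __name__ == "__main__":
    # A 2-D operand sharded as ('x', 'y'); argmax over the last axis
    # leaves only the first dimension, still sharded over 'x'.
    print(argminmax_out_spec(("x", "y"), -1))
```

Dropping the entry reflects that a reduced dimension no longer exists in the output. When the reduced dimension was itself sharded, a real implementation also needs a cross-device reduction, which this sketch does not model.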