jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[pallas] Allow user to pass 64-bit indices to `pl.{load,store,...}`. #24782

Open · copybara-service[bot] opened 4 days ago

copybara-service[bot] commented 4 days ago

[pallas] Allow user to pass 64-bit indices to `pl.{load,store,...}`.
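The following is a minimal sketch, not taken from the PR, of what passing 64-bit indices looks like from the user side. It makes several assumptions. x64 mode is enabled, because otherwise `jnp.int64` values are truncated to 32 bits. `interpret=True` is used so the kernel runs on CPU. The refs are indexed with `pl.ds`, and `pl.load(x_ref, (pl.ds(src, 4),))` and `pl.store(o_ref, (pl.ds(dst, 4),), ...)` are roughly the functional spellings of the same read and write. `copy_window_kernel` and `copy_window` are hypothetical names. Interpret mode only exercises the indexing semantics, not the Triton or Mosaic lowering this change may target.

```python
import jax

# 64-bit integer dtypes are only available in x64 mode; enable it before creating arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.experimental import pallas as pl


def copy_window_kernel(x_ref, o_ref):
    # Hypothetical kernel: derive int64 offsets from the program id so the
    # kernel does not capture any array constants.
    pid = pl.program_id(0).astype(jnp.int64)
    src = pid + 2  # int64 start offset for the read
    dst = pid      # int64 start offset for the write
    # Read 4 elements starting at `src` and write them starting at `dst`.
    o_ref[pl.ds(dst, 4)] = x_ref[pl.ds(src, 4)]


def copy_window(x):
    # Single-program grid; with no BlockSpecs each ref covers the whole array.
    return pl.pallas_call(
        copy_window_kernel,
        out_shape=jax.ShapeDtypeStruct((4,), x.dtype),
        grid=(1,),
        interpret=True,  # run through the Pallas interpreter on CPU
    )(x)


x = jnp.arange(8, dtype=jnp.float32)
print(copy_window(x))  # expected values: 2, 3, 4, 5
```

Deriving the offsets from `pl.program_id` keeps them as traced int64 values rather than closed-over constants, which is the situation where an index's dtype actually reaches the load and store.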