jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[Pallas TPU] [WIP] Add vector support to `pl.debug_print` #25099

Open · copybara-service[bot] opened this pull request 3 hours ago
