jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[Mosaic:TPU] Add relayout for adding minor implicit dim and relax some offset restrictions on similar shape cast #25105

Open copybara-service[bot] opened 5 days ago

copybara-service[bot] commented 5 days ago

[Mosaic:TPU] Add relayout for adding minor implicit dim and relax some offset restrictions on similar shape cast

This factors the minor-dimension insertion logic out of the apply-vector-layout shape cast rule, relaxes some of the offset restrictions on it, and reuses it for the relayout.
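For readers less familiar with the operation, here is a minimal array-level sketch of a shape cast that inserts a minor (trailing) dimension of size 1. It only illustrates the shape change as seen from user code; it does not show the Mosaic compiler pass, vector layouts, or the offset handling touched by this change. The helper name `add_minor_dim` and the `(8, 128)` shape are assumptions chosen for illustration.

```python
try:
    import jax.numpy as jnp
except ImportError:
    # Fallback so the sketch still runs without JAX installed;
    # NumPy's expand_dims has the same semantics for this illustration.
    import numpy as jnp


def add_minor_dim(x):
    """Hypothetical helper: append a trailing (minor-most) dimension of size 1."""
    return jnp.expand_dims(x, axis=-1)


# Illustrative 2D input; the shape is an assumption, not taken from the PR.
x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
y = add_minor_dim(x)

# The element count is unchanged; only the shape gains a trailing unit dim.
print(x.shape, "->", y.shape)
```

Each element `x[i, j]` ends up at `y[i, j, 0]`, so the cast changes how the data is indexed without changing its values.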