google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[export] Add backwards compatibility test for Pallas call on GPUs. #21036

Closed by copybara-service[bot] 2 weeks ago

copybara-service[bot] commented 2 weeks ago

Note that this adds only a minimal safety net against non-backwards-compatible changes. We really should have more tests that cover more of the Triton MLIR.

Also enable serialization of such calls.