jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Track mapping of platform aliases to compile-only backends #25107

Open jaro-sevcik opened 5 days ago

jaro-sevcik commented 5 days ago

The patch tracks the mapping from platform aliases to compile-only backend platform names. This mapping lets us canonicalize platform names correctly (e.g. 'gpu' -> 'cuda') even when only compile-only backends are available for the platform.
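To illustrate the idea, here is a minimal standalone sketch of alias canonicalization that falls back to a tracked alias-to-compile-only mapping. It is not the code in this patch. `ALIAS_TO_PLATFORMS`, `PlatformRegistry`, `register_compile_only`, and `canonicalize` are hypothetical names chosen for illustration and are not JAX's internal API. The sketch assumes that an alias such as 'gpu' may resolve to several concrete platforms ('cuda', 'rocm'), and that a regular backend wins over a compile-only one.

```python
# Hypothetical sketch only: these names are illustrative, not JAX internals.

# Assumed alias table: an alias can stand for several concrete platforms.
ALIAS_TO_PLATFORMS: dict[str, list[str]] = {"gpu": ["cuda", "rocm"]}


class PlatformRegistry:
    def __init__(self) -> None:
        self.backends: dict[str, object] = {}               # regular backends
        self.compile_only_backends: dict[str, object] = {}  # compile-only backends
        # Tracked mapping: alias -> platform name of a compile-only backend.
        self.alias_to_compile_only: dict[str, str] = {}

    def register_compile_only(self, platform: str, backend: object) -> None:
        self.compile_only_backends[platform] = backend
        # Record every alias that covers this platform; the first
        # registration for an alias is kept.
        for alias, platforms in ALIAS_TO_PLATFORMS.items():
            if platform in platforms:
                self.alias_to_compile_only.setdefault(alias, platform)

    def canonicalize(self, name: str) -> str:
        name = name.lower()
        # Concrete platform names are already canonical.
        if name in self.backends or name in self.compile_only_backends:
            return name
        # Prefer an alias target that has a regular backend.
        for candidate in ALIAS_TO_PLATFORMS.get(name, []):
            if candidate in self.backends:
                return candidate
        # Otherwise fall back to the tracked compile-only mapping.
        if name in self.alias_to_compile_only:
            return self.alias_to_compile_only[name]
        raise RuntimeError(f"Unknown or unavailable platform: {name!r}")


registry = PlatformRegistry()
registry.register_compile_only("cuda", backend=object())
print(registry.canonicalize("gpu"))  # cuda
```

Without the `alias_to_compile_only` table, `canonicalize("gpu")` in this sketch would raise, because no regular backend exists for 'cuda' or 'rocm'. Recording the mapping at registration time avoids that failure.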

jaro-sevcik commented 5 days ago

This is an alternative to https://github.com/jax-ml/jax/pull/25033 to address https://github.com/jax-ml/jax/issues/23971.