jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[Lowering] Stable IR #25123

Open · yliu120 opened 2 hours ago

yliu120 commented 2 hours ago

We have observed runtime-specific bits inside the lowered IR when using host callbacks.

Currently, host callbacks are lowered to a custom call that takes the host callback pointer as its first argument. That pointer is a runtime address, and it differs in every process. As a result, the lowered StableHLO for the same program differs from run to run.
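To make the problem concrete, here is a toy model (not JAX's actual lowering code) in which a hypothetical `toy_lower_with_pointer` embeds the callback's address into the emitted text, the way a raw function pointer ends up as the custom call's first operand. The output string is illustrative only, not real StableHLO syntax, and creating the callback twice stands in for running the same program in two processes.

```python
def toy_lower_with_pointer(callback) -> str:
    # In CPython, id() returns the object's memory address, which is only
    # meaningful inside the current process (like a C function pointer).
    address = id(callback)
    # Illustrative text only; this is not real StableHLO syntax.
    return f"custom_call(callback_ptr={address:#x})"


def make_callback():
    # Each call builds a fresh function object, standing in for the same
    # user callback being created again in a new process.
    def callback(x):
        return x + 1
    return callback


run_a = make_callback()
run_b = make_callback()
ir_a = toy_lower_with_pointer(run_a)
ir_b = toy_lower_with_pointer(run_b)
print(ir_a == ir_b)  # False: same program, different IR text
```

The program logic is identical in both "runs", yet the IR text differs because the descriptor is a process-local address.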

We should build a registry for the CpuCallbacks and GpuCallbacks so that the descriptor can be a virtual ID rather than a physical function pointer. If our program is always lowered in a deterministic order, every callback will get the same virtual ID on every run, across processes.
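A minimal sketch of the proposed registry, assuming callbacks are registered in the same order on every run. `CallbackRegistry`, `register`, `lookup`, and `toy_lower` are hypothetical names for illustration, not existing JAX APIs, and the emitted text is a toy stand-in for the custom call. The IR carries only the virtual ID; at execution time each process resolves that ID through its own registry.

```python
from typing import Callable, Dict, List


class CallbackRegistry:
    """Hypothetical registry mapping host callbacks to virtual IDs.

    IDs are handed out in registration order, so a program that registers
    its callbacks in a deterministic order gets the same IDs in every process.
    """

    def __init__(self) -> None:
        # Holding a strong reference keeps each callback alive, so its id()
        # cannot be reused by another object while it is registered.
        self._callbacks: List[Callable] = []
        self._ids: Dict[int, int] = {}  # id(callback) -> virtual ID

    def register(self, callback: Callable) -> int:
        key = id(callback)
        if key not in self._ids:
            self._ids[key] = len(self._callbacks)
            self._callbacks.append(callback)
        return self._ids[key]

    def lookup(self, virtual_id: int) -> Callable:
        # Runtime side: turn the ID from the IR back into this process's callback.
        return self._callbacks[virtual_id]


def toy_lower(registry: CallbackRegistry, callbacks: List[Callable]) -> str:
    # Emit one toy custom-call line per callback, in program order; the
    # descriptor is the virtual ID instead of a memory address.
    return "\n".join(
        f"custom_call(callback_id={registry.register(cb)})" for cb in callbacks
    )


def simulate_run():
    # Each call models a fresh process: new registry, new function objects.
    registry = CallbackRegistry()

    def log_fn(x):
        return x

    def print_fn(x):
        return x

    ir = toy_lower(registry, [log_fn, print_fn, log_fn])
    return registry, ir


reg_a, ir_a = simulate_run()
reg_b, ir_b = simulate_run()
print(ir_a == ir_b)  # True: identical IR text across "processes"
```

Reusing a callback (here `log_fn`) maps to the same ID, and the stability guarantee holds only as long as registration order is deterministic, which matches the condition stated above.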

yliu120 commented 2 hours ago

@hawkinsp