jaro-sevcik opened this issue 5 months ago
Yes. That's correct, and it's how it works at the moment. The custom partitioning ultimately refers to a Python object, which is why it's not stable from run to run.
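A toy sketch, not JAX's actual cache-key code, of why an embedded object address defeats a content-hash cache. The names `Descriptor`, `toy_module_text`, and `toy_cache_key` are hypothetical, and SHA-256 over the module text is an assumed stand-in for whatever hashing JAX really does. Two live objects model the same descriptor created in two separate runs.

```python
import hashlib


class Descriptor:
    """Hypothetical stand-in for the Python object behind custom_partitioning."""


def toy_module_text(descriptor: object) -> str:
    # Toy "serialized module": the descriptor's id() plays the role of the
    # address stored in backend_config. In CPython, id() is the memory address.
    return ('custom_call_target="CustomSPMDPartitioning", '
            f'backend_config="{id(descriptor)}"')


def toy_cache_key(module_text: str) -> str:
    # Hash the whole text, so any byte difference yields a different key.
    return hashlib.sha256(module_text.encode("utf-8")).hexdigest()


# Objects alive at the same time always have distinct ids, so these two
# "runs" produce different text and therefore different keys.
run_1, run_2 = Descriptor(), Descriptor()
print(toy_cache_key(toy_module_text(run_1)) == toy_cache_key(toy_module_text(run_2)))  # False
```

Across real process restarts the address usually changes as well, so each run writes a fresh entry instead of hitting the existing one.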
If this only covers the Python callback, could we just remove it from the key? Not everything is 100% versioned right now (for example, XLA isn't versioned), so it would end up being the end user's responsibility to make sure the cache is handled correctly?
Or could we get the Python AST of the callbacks and hash it? I see that as more work, and I'm not sure it is useful.
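A minimal sketch of the AST-hashing idea, assuming the callback's source text is available as a string. `ast_fingerprint` and the sample sources are hypothetical; the point is that formatting and comment changes leave the fingerprint unchanged while a logic change does not.

```python
import ast
import hashlib


def ast_fingerprint(source: str) -> str:
    """Hash the syntax tree of Python source, ignoring layout and comments."""
    tree = ast.parse(source)
    # ast.dump() omits line/column attributes by default, and comments never
    # reach the tree, so formatting-only edits produce the same dump.
    return hashlib.sha256(ast.dump(tree).encode("utf-8")).hexdigest()


# Hypothetical callback sources; only the third changes behavior.
original = "def partition_fn(x):\n    return x\n"
reformatted = "def partition_fn( x ):\n\n    return x  # unchanged logic\n"
edited = "def partition_fn(x):\n    return -x\n"

print(ast_fingerprint(original) == ast_fingerprint(reformatted))  # True
print(ast_fingerprint(original) == ast_fingerprint(edited))       # False
```

This only fingerprints the callback's own body: helpers or globals it calls are not covered, and fetching source at runtime (for example with `inspect.getsource`) only works for functions defined in files, which is part of the extra work mentioned above.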
@hawkinsp As Frederic mentioned, could the Python callbacks for custom_partitioning be removed from the HLO used for generating the cache key?
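A rough sketch of what "removing the callbacks from the key" could look like: mask the `backend_config` of `CustomSPMDPartitioning` custom calls before hashing. The regex, the `stable_cache_key` helper, and the module text below are hypothetical toy inputs, not real HLO output or a JAX API.

```python
import hashlib
import re

# Hypothetical pattern for the backend_config attribute in HLO-like text.
# Real module text may format this differently; this is only an illustration.
_BACKEND_CONFIG = re.compile(r'backend_config="[^"]*"')


def stable_cache_key(module_text: str) -> str:
    """Hash module text with CustomSPMDPartitioning backend_config values masked."""
    masked_lines = []
    for line in module_text.splitlines():
        if "CustomSPMDPartitioning" in line:
            # Replace the run-dependent value with a fixed placeholder.
            line = _BACKEND_CONFIG.sub('backend_config="<masked>"', line)
        masked_lines.append(line)
    return hashlib.sha256("\n".join(masked_lines).encode("utf-8")).hexdigest()


# Toy module lines from two "runs" that differ only in the embedded address.
run_1 = ('%out = f32[8] custom-call(f32[8] %x), '
         'custom_call_target="CustomSPMDPartitioning", backend_config="140234567"')
run_2 = run_1.replace("140234567", "139876543")
# A genuinely different program (different shape) must still get a new key.
resized = run_1.replace("f32[8]", "f32[16]")

print(stable_cache_key(run_1) == stable_cache_key(run_2))    # True
print(stable_cache_key(run_1) == stable_cache_key(resized))  # False
```

The trade-off is the one raised earlier in the thread: with the config masked, editing the partitioning callback itself no longer changes the key, so a stale cache entry could be reused and keeping the cache valid becomes the user's responsibility.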
Another alternative is to do a partial compilation and stop after GSPMD partitioning. Then we would get the exact graph we need for the key. @hawkinsp What do you think of those options?
Description
The compilation cache does not trigger for jitted functions with custom_partitioning ops. After running the JAX program multiple times, there is a separate cache entry with a different hash for each run.
Here is the invocation:
Here are the contents of the cache afterwards:
The program with the custom-partitioned op:
I believe the cache misses because the `backend_config` parameter to the `CustomSPMDPartitioning` custom call is the address of some descriptor data structure, and that address is different from run to run. As a result, the MLIR differs between runs and the cache never triggers. Below is the HLO for the function `f`; the `backend_config` here is the address that differs across runs.

System info (python version, jaxlib version, accelerator, etc.)