If we deserialize executables via the wrapper C API client, the compile options are ignored. In practice, this means the JAX persistent compilation cache fails on non-zero ranks when they deserialize executables that were compiled for rank zero.
JAX repro:
Run (python test.py -r0 &) && (python test.py -r1) twice on a machine with 2+ GPUs.
The output (from rank 1) then contains the following error message:
/opt/jax/jax/_src/compiler.py:691: UserWarning: Error reading persistent compilation cache entry for 'jit__lambda_':
XlaRuntimeError: INVALID_ARGUMENT: Device assignment (Computations: 1 Replicas: 1
Computation 0: 0
) does not have any local devices.
warnings.warn(
2024-10-14 09:16:30,222 PERSISTENT COMPILATION CACHE MISS for 'jit__lambda_' with key 'jit__lambda_-502ff86f0064419e429f73e9641f94cc3ab91a275910dec17b3ba6186556a297'
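The report does not include test.py. Below is a hypothetical sketch of what such a two-process script might look like. It assumes that each rank is pinned to the GPU whose index equals its rank and that both ranks share one persistent cache directory. The coordinator port, the cache path, and the helper names `parse_rank` and `distributed_kwargs` are illustrative assumptions, not taken from the original script, and the sketch has not been confirmed to reproduce the error.

```python
import argparse


def parse_rank(argv=None):
    """Parse the process rank from `-r0` / `-r1` style arguments."""
    parser = argparse.ArgumentParser(description="Hypothetical two-process cache repro")
    parser.add_argument("-r", "--rank", type=int, required=True)
    return parser.parse_args(argv).rank


def distributed_kwargs(rank, num_processes=2, port=12345):
    """Build keyword arguments for jax.distributed.initialize.

    Assumption: each rank owns only the GPU matching its rank, so rank 1
    has device 1 locally and no access to device 0.
    """
    return {
        "coordinator_address": f"localhost:{port}",  # hypothetical port
        "num_processes": num_processes,
        "process_id": rank,
        "local_device_ids": [rank],
    }


def main(argv=None):
    # Import JAX lazily so the helpers above can be used without GPUs.
    import jax
    import jax.numpy as jnp

    rank = parse_rank(argv)
    # Both ranks share one persistent cache directory (hypothetical path).
    jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
    # Cache every compilation, however fast, so the tiny jitted lambda is written.
    jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
    jax.distributed.initialize(**distributed_kwargs(rank))
    # On the second run, if the cache keys match, rank 1 may load an
    # executable that rank 0 wrote for device 0.
    print(rank, jax.jit(lambda x: x + 1)(jnp.ones(4)))
```

To use it as test.py, append a call to `main()` (for example behind `if __name__ == "__main__":`) and run the command above twice.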