openxla / xla


Deserialization of executables fails on non-zero ranks when deserializing single-device executable #18286

Open · jaro-sevcik opened this issue 3 weeks ago

jaro-sevcik commented 3 weeks ago

If executables are deserialized via the wrapper C API client, the compile options are ignored. In practice, this means that the JAX compilation cache fails when an executable serialized on rank zero is deserialized on a non-zero rank.
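
For orientation, here is a conceptual sketch of the failing path, using hypothetical helper names (this is not the real wrapper C API): on a cache hit, the executable is deserialized, but the compile options, including the device assignment for the current process, are dropped.

# Conceptual sketch only; load_from_cache and client.deserialize_executable
# are hypothetical names, not the real wrapper C API.
def load_from_cache(cache, key, compile_options, client):
  serialized = cache.get(key)
  if serialized is None:
    return None
  # The bug described above: compile_options is ignored here, so the loaded
  # executable keeps the device assignment baked in at serialization time
  # (rank 0's single device) rather than the current process's assignment.
  return client.deserialize_executable(serialized)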

JAX repro:

import jax
import jax.numpy as jnp
import logging
import argparse
import socket

# Each process is launched with its rank: python test.py -r <rank>.
parser = argparse.ArgumentParser(
    prog='mock-test',
    description='Tests mocking',
    epilog='...')
parser.add_argument('-r', '--rank', type=int)
args = parser.parse_args()

# Enable compiler debug logging on rank 1, where the cache read fails.
if args.rank == 1:
  logging.basicConfig(format='%(asctime)s %(message)s')
  logging.getLogger("jax._src.compiler").setLevel(logging.DEBUG)

# Write every compilation to the persistent cache, regardless of compile time.
jax.config.update("jax_compilation_cache_dir", "/tmp/compilation_cache")
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

# Two processes, one local GPU per process (device id == rank).
jax.distributed.initialize(socket.gethostname() + ":1234", 2, args.rank,
                           local_device_ids=[args.rank])

f = jax.jit(lambda x: x)
jax.block_until_ready(f(jnp.zeros((1,))))

Run (python test.py -r0 &) && (python test.py -r1) twice on a machine with two or more GPUs. The first run populates the persistent compilation cache; on the second run, rank 1 fails when reading the cached executable back.
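
Equivalently, the two ranks can be launched from Python (a convenience sketch that assumes the repro above is saved as test.py):

import subprocess

def run_once():
  # Start both ranks; they rendezvous via jax.distributed.initialize.
  procs = [subprocess.Popen(["python", "test.py", "-r", str(rank)])
           for rank in (0, 1)]
  for p in procs:
    p.wait()

run_once()  # populates the persistent compilation cache
run_once()  # rank 1 now fails to read the cached executable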

The output (from rank 1) then contains the following error message:

/opt/jax/jax/_src/compiler.py:691: UserWarning: Error reading persistent compilation cache entry for 'jit__lambda_':
XlaRuntimeError: INVALID_ARGUMENT: Device assignment (Computations: 1 Replicas: 1
Computation 0: 0
) does not have any local devices.
  warnings.warn(
2024-10-14 09:16:30,222 PERSISTENT COMPILATION CACHE MISS for 'jit__lambda_' with key 'jit__lambda_-502ff86f0064419e429f73e9641f94cc3ab91a275910dec17b3ba6186556a297'
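
For context, the check this error corresponds to can be illustrated as follows (illustrative Python only, not XLA's implementation): the deserialized executable's device assignment contains only global device 0, while rank 1 owns only local device 1, so no device in the assignment is local to the process.

# Illustrative only; not XLA's actual code.
def local_devices_in_assignment(assignment_device_ids, local_device_ids):
  """Device ids from the executable's device assignment owned by this process."""
  return [d for d in assignment_device_ids if d in set(local_device_ids)]

assignment = [0]              # "Computation 0: 0" in the error above
rank1_local_device_ids = [1]  # local_device_ids=[args.rank] with rank 1
assert local_devices_in_assignment(assignment, rank1_local_device_ids) == []
# -> INVALID_ARGUMENT: device assignment does not have any local devices
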
loislo commented 2 days ago

The PR was reverted due to crashes in production.