
Ahead-of-time compilation on GPU fails without a GPU device #23971

Open jaro-sevcik opened 1 week ago

jaro-sevcik commented 1 week ago

Description

With a local GPU device present, one can compile ahead-of-time even for a different GPU model and topology, as illustrated by the code at the end of this report (this only works with a recent XLA that includes https://github.com/openxla/xla/pull/16913). For the corresponding deserialization code, see https://gist.github.com/jaro-sevcik/3495718bb04c6096c0f998fc29220c2b.

However, the same program fails when no GPU device is present, because the XLA/JAX runtime maps the "cuda"/"rocm" platform names back to the legacy "gpu" alias, presumably for compatibility with older scripts. Here is the error message we get when compiling deviceless:

RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. ...
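
For reference, the same platform lookup failure can be triggered directly on a CPU-only host (a minimal sketch for illustration; it exercises the same canonicalize_platform path but is not part of the AOT reproduction below):

import jax

# With no GPU plugin present, asking for the legacy "gpu" alias fails in
# canonicalize_platform with the error quoted above.
jax.devices("gpu")
# RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are
# instances of gpu are present. Platforms present are: cpu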

If we bypass some of the renaming, for example with the patch below, the deviceless compilation and serialization succeed.

Here is the patch that makes deviceless compilation succeed for NVIDIA GPUs:

diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 1d3c50403..d1d56334d 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -798,6 +798,7 @@ def canonicalize_platform(platform: str) -> str:
   for p in platforms:
     if p in b.keys():
       return p
+  return "cuda"
   raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
                      f"platforms that are instances of {platform} are present. "
                      "Platforms present are: " + ",".join(b.keys()))

Alternatively, if one removes the renaming in XLA, deviceless AOT compilation also passes:

diff --git a/xla/python/py_client.h b/xla/python/py_client.h
index 374b7f6d2e..73543e91a6 100644
--- a/xla/python/py_client.h
+++ b/xla/python/py_client.h
@@ -96,16 +96,7 @@ class PyClient {
   }

   std::string_view platform_name() const {
-    // TODO(phawkins): this is a temporary backwards compatibility shim. We
-    // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but
-    // we haven't yet updated JAX clients that expect "gpu". Migrate users and
-    // remove this code.
-    if (ifrt_client_->platform_name() == "cuda" ||
-        ifrt_client_->platform_name() == "rocm") {
-      return "gpu";
-    } else {
-      return ifrt_client_->platform_name();
-    }
+    return ifrt_client_->platform_name();
   }
   std::string_view platform_version() const {
     return ifrt_client_->platform_version();

Ahead-of-time compilation and serialization code:

import jax
import jax.numpy as jp
import jax.experimental.topologies as topologies
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental.serialize_executable import (
    deserialize_and_load,
    serialize,
)

# Contents of https://github.com/openxla/xla/blob/main/xla/tools/hlo_opt/gpu_specs/a100_pcie_80.txtpb
target_config_proto = """gpu_device_info {
  threads_per_block_limit: 1024
  threads_per_warp: 32
  shared_memory_per_block: 49152
  shared_memory_per_core: 167936
  threads_per_core_limit: 2048
  core_count: 108
  fpus_per_core: 64
  block_dim_limit_x: 2147483647
  block_dim_limit_y: 65535
  block_dim_limit_z: 65535
  memory_bandwidth: 2039000000000
  l2_cache_size: 41943040
  clock_rate_ghz: 1.1105
  device_memory_size: 79050250240
  shared_memory_per_block_optin: 166912
  cuda_compute_capability {
    major: 8
  }
  registers_per_core_limit: 65536
  registers_per_block_limit: 65536
}
platform_name: "CUDA"
dnn_version_info {
  major: 8
  minor: 3
  patch: 2
}
device_description_str: "A100 80GB"
"""

# Requested topology:
# 1 machine
# 1 process per machine
# 2 devices per process
topo = topologies.get_topology_desc(
  "topo",
  "cuda",
  target_config=target_config_proto,
  topology="1x1x2")

# Create the mesh and sharding.
mesh = Mesh(topo.devices, ('x',))
s = NamedSharding(mesh, P('x', None))

def fn(x):
  return jp.sum(x * x)

# JIT with fully specified shardings.
fn = jax.jit(fn, in_shardings=s, out_shardings=NamedSharding(mesh, P()))

# Provide input shape(s).
x_shape = jax.ShapeDtypeStruct(
          shape=(16, 16),
          dtype=jp.dtype('float32'),
          sharding=s)

# Lower and compile.
lowered = fn.lower(x_shape)
compiled = lowered.compile()

# Serialize the compilation results.
serialized, in_tree, out_tree = serialize(compiled)
print("Executable compiled and serialized")

# Write the serialized code to a file.
fname = "square.xla.bin"
with open(fname, "wb") as binary_file:
    binary_file.write(serialized)

print(f"Executable saved to '{fname}'")

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.34.dev20240926+b6d668e0d
jaxlib: 0.4.34.dev20240927
numpy:  1.26.4
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='...', release='5.4.0-92-generic', version='#103-Ubuntu SMP Fri Nov 26 16:13:00 UTC 2021', machine='x86_64')
jaro-sevcik commented 1 week ago

@hawkinsp, could you take a look?