Description

With a local GPU device, one can compile ahead-of-time even for a different GPU and topology, as illustrated by the code at the end of this report (this only works with recent XLA; see https://github.com/openxla/xla/pull/16913). For the deserialization code, see https://gist.github.com/jaro-sevcik/3495718bb04c6096c0f998fc29220c2b.

However, the same program fails without a device because the XLA/JAX runtime renames between the "gpu" and "cuda"/"rocm" platform names, presumably for compatibility with legacy scripts. Here is the error message we get when compiling deviceless:
RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. ...
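For context, the error comes from the platform-alias resolution in jax/_src/xla_bridge.py. A minimal sketch of that logic (simplified; the real function consults the registered backends, and the names here are approximate):

# Simplified sketch of canonicalize_platform from jax/_src/xla_bridge.py.
_platform_aliases = {"gpu": ["cuda", "rocm"]}

def canonicalize_platform(platform, available_backends):
    # "gpu" is an alias that must resolve to a concrete platform.
    for p in _platform_aliases.get(platform, [platform]):
        if p in available_backends:
            return p
    # Deviceless, no "cuda"/"rocm" backend is registered, so we end up here.
    raise RuntimeError(f"Unknown backend: '{platform}' requested, ...")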
If we bypass some of the renaming, deviceless compilation and serialization succeed. Here is a patch that does so for NVIDIA GPUs (it falls back to "cuda" when no matching backend is present):
diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 1d3c50403..d1d56334d 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -798,6 +798,7 @@ def canonicalize_platform(platform: str) -> str:
   for p in platforms:
     if p in b.keys():
       return p
+  return "cuda"
   raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
                      f"platforms that are instances of {platform} are present. "
                      "Platforms present are: " + ",".join(b.keys()))
Alternatively, if one removes the renaming shim in XLA itself, deviceless AOT compilation also passes:
diff --git a/xla/python/py_client.h b/xla/python/py_client.h
index 374b7f6d2e..73543e91a6 100644
--- a/xla/python/py_client.h
+++ b/xla/python/py_client.h
@@ -96,16 +96,7 @@ class PyClient {
   }

   std::string_view platform_name() const {
-    // TODO(phawkins): this is a temporary backwards compatibility shim. We
-    // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but
-    // we haven't yet updated JAX clients that expect "gpu". Migrate users and
-    // remove this code.
-    if (ifrt_client_->platform_name() == "cuda" ||
-        ifrt_client_->platform_name() == "rocm") {
-      return "gpu";
-    } else {
-      return ifrt_client_->platform_name();
-    }
+    return ifrt_client_->platform_name();
   }
   std::string_view platform_version() const {
     return ifrt_client_->platform_version();
Ahead-of-time compilation and serialization code:
import jax
import jax.numpy as jp
import jax.experimental.topologies as topologies
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental.serialize_executable import (
    deserialize_and_load,
    serialize,
)
# Contents of https://github.com/openxla/xla/blob/main/xla/tools/hlo_opt/gpu_specs/a100_pcie_80.txtpb
target_config_proto = """gpu_device_info {
  threads_per_block_limit: 1024
  threads_per_warp: 32
  shared_memory_per_block: 49152
  shared_memory_per_core: 167936
  threads_per_core_limit: 2048
  core_count: 108
  fpus_per_core: 64
  block_dim_limit_x: 2147483647
  block_dim_limit_y: 65535
  block_dim_limit_z: 65535
  memory_bandwidth: 2039000000000
  l2_cache_size: 41943040
  clock_rate_ghz: 1.1105
  device_memory_size: 79050250240
  shared_memory_per_block_optin: 166912
  cuda_compute_capability {
    major: 8
  }
  registers_per_core_limit: 65536
  registers_per_block_limit: 65536
}
platform_name: "CUDA"
dnn_version_info {
  major: 8
  minor: 3
  patch: 2
}
device_description_str: "A100 80GB"
"""
# Requested topology:
# 1 machine
# 1 process per machine
# 2 devices per process
topo = topologies.get_topology_desc(
    "topo",
    "cuda",
    target_config=target_config_proto,
    topology="1x1x2")
# Create the mesh and sharding.
mesh = Mesh(topo.devices, ('x',))
s = NamedSharding(mesh, P('x', None))
def fn(x):
    return jp.sum(x * x)
# JIT with fully specified shardings.
fn = jax.jit(fn, in_shardings=s, out_shardings=NamedSharding(mesh, P()))
# Provide input shape(s).
x_shape = jax.ShapeDtypeStruct(
    shape=(16, 16),
    dtype=jp.dtype('float32'),
    sharding=s)
# Lower and compile.
lowered = fn.lower(x_shape)
compiled = lowered.compile()
# Serialize the compilation results.
serialized, in_tree, out_tree = serialize(compiled)
print("Executable compiled and serialized")
# Write the serialized code to a file.
fname = "square.xla.bin"
with open(fname, "wb") as binary_file:
binary_file.write(serialized)
print(f"Executable saved to '{fname}'")
System info (python version, jaxlib version, accelerator, etc.)