jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Unable to build jaxlib with debug symbols for GPU #23564

Open juuso-oskari opened 2 weeks ago

juuso-oskari commented 2 weeks ago

Description

I am trying to build jaxlib with debug symbols for XLA using the following command:

python build/build.py --enable_cuda --bazel_options=--override_repository=xla=/xla --bazel_options=--jobs=1 --bazel_options=--compilation_mode=dbg

The build proceeds without problems until the final step, where linking xla_extension.so fails:

[1 / 2] Linking external/xla/xla/python/xla_extension.so; 102s local
ERROR: /root/.cache/bazel/_bazel_root/bee4ad1fd43279be7a03b33426e824d5/external/xla/xla/python/BUILD:1257:21: Linking external/xla/xla/python/xla_extension.so failed: (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target @xla//xla/python:xla_extension.so) 
  (cd /root/.cache/bazel/_bazel_root/bee4ad1fd43279be7a03b33426e824d5/execroot/__main__ && \
  exec env - \
    LD_LIBRARY_PATH=/opt/amazon/efa/lib:/usr/local/cuda/compat/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 \
    PATH=/opt/amazon/efa/bin:/usr/local/mpi/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/ucx/bin \
    PWD=/proc/self/cwd \
    TF_CUDA_COMPUTE_CAPABILITIES=sm_50,sm_60,sm_70,sm_80,compute_90 \
  external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc @bazel-out/k8-dbg/bin/external/xla/xla/python/xla_extension.so-2.params)
# Configuration: 5ca7bfb6889cfb0ee4db260a87240da2629ea60addd0c8f57834d31891b46935
# Execution platform: @local_execution_config_platform//:platform
collect2: fatal error: ld terminated with signal 9 [Killed]
compilation terminated.
[2 / 2] checking cached actions
Target //jaxlib/tools:build_wheel failed to build
INFO: Elapsed time: 162.700s, Critical Path: 160.90s
INFO: 2 processes: 2 internal.
FAILED: Build did NOT complete successfully

The error does not occur if I omit --enable_cuda, so it is probably related to CUDA. However, building with --enable_cuda succeeds if I do not request debug symbols (that is, without --bazel_options=--compilation_mode=dbg).
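
A note on the log above: `ld terminated with signal 9 [Killed]` means the linker received SIGKILL. On Linux a common cause is the kernel's out-of-memory killer, because linking a debug build usually needs much more memory than linking an optimized one. That is an assumption here, since the log does not confirm it. A minimal sketch for checking it, where `show_mem` is a hypothetical helper name and not part of any tool:

```shell
# Print RAM and swap figures from a meminfo-style file (default: /proc/meminfo).
show_mem() {
  grep -E '^(MemTotal|MemAvailable|SwapTotal|SwapFree):' "${1:-/proc/meminfo}"
}

# Run the check only where /proc/meminfo exists (Linux).
if [ -r /proc/meminfo ]; then
  show_mem
fi

# After a failed link, look for OOM-killer messages in the kernel log.
# This often needs root and may not be available inside a container.
dmesg 2>/dev/null | grep -i -E 'out of memory|killed process' \
  || echo "no OOM messages visible"
```

If the kernel log shows the linker being killed, the failure is a memory limit rather than a CUDA-specific link error.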

System info (python version, jaxlib version, accelerator, etc.)

python version: Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] on linux
jax: latest commit https://github.com/google/jax.git
xla: latest commit https://github.com/openxla/xla
accelerator (info from nvidia-smi): NVIDIA-SMI 560.28.03, Driver Version: 560.28.03, CUDA Version: 12.6, NVIDIA GeForce RTX 4080 Laptop GPU

To reproduce the error with a docker container:

# download local jax and xla repos
git clone https://github.com/openxla/xla.git
git clone https://github.com/google/jax
# start up nvidia docker container for jax
sudo docker run --name jax --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -it -d -v $PWD/jax:/jax -v $PWD/xla:/xla -v ~/.cache/bazel:/root/.cache/bazel nvcr.io/nvidia/jax:24.04-py3
# enter container in interactive mode
docker exec -it jax bash
# build jaxlib (produces the error)
cd /jax
python build/build.py --enable_cuda --bazel_options=--override_repository=xla=/xla --bazel_options=--jobs=4 --bazel_options=--compilation_mode=dbg

justinjfu commented 2 weeks ago

cc @yashk2810 perhaps? Any ideas on what could be causing this build failure?

juuso-oskari commented 2 weeks ago

@yashk2810 The build succeeds after pulling the latest xla and jax repos, but I now have another problem:

The reason I wanted to build jaxlib from source is to debug the XLA backend that is called when running a JAX script.

I start the debugger as follows:

gdb python
(gdb) break InitializePjitFunction # add breakpoint at one of the cpp functions that gets called
(gdb) run simple_jax_script.py

...

Thread 1 "python" hit Breakpoint 1, jax::(anonymous namespace)::InitializePjitFunction (fn_obj=0x7fffd6f3c310, function_name="fft", fun=std::optional<nanobind::callable> = {...}, cache_miss=..., static_argnums=std::vector of length 2, capacity 2 = {...}, static_argnames=std::vector of length 2, capacity 2 = {...}, donate_argnums=std::vector of length 0, capacity 0, pytree_registry=std::shared_ptr<xla::PyTreeRegistry> (use count 1, weak count 1) = {...}, shard_arg_fallback=..., cache=std::shared_ptr<jax::(anonymous namespace)::PjitFunctionCache> (use count 1, weak count 0) = {...}) at external/xla/xla/python/pjit.cc:1044
1044    external/xla/xla/python/pjit.cc: No such file or directory.
(gdb) 

So gdb cannot find the XLA source files. In my case they are located in /xla, but gdb looks for them under the relative path external/xla instead. How can I build jaxlib so that the debug info still points to the original C++ source files?
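
One possible workaround that does not require rebuilding is gdb's `set substitute-path` command, which rewrites the source path recorded in the debug info. The sketch below assumes the recorded prefix is exactly `external/xla`, as in the gdb output above, and that the checkout is mounted at `/xla`. The file name `xla-paths.gdb` is hypothetical.

```shell
# Write a gdb command file that maps the path stored in the debug info
# (external/xla/...) onto the local XLA checkout mounted at /xla.
cat > xla-paths.gdb <<'EOF'
set substitute-path external/xla /xla
EOF

# Start gdb with the command file (left as a comment so the snippet only writes the file):
# gdb -x xla-paths.gdb --args python simple_jax_script.py
```

If gdb still reports the file as missing, running `info source` at the breakpoint shows the recorded file name and compilation directory. The first argument to `set substitute-path` may need to match that instead.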