Open juuso-oskari opened 2 weeks ago
cc @yashk2810 perhaps? Any ideas on what could be causing this build failure?
@yashk2810 I managed to build by pulling the latest xla and jax repos, but I now have a different problem.
I wanted to build jaxlib from source so that I can debug the XLA backend that gets called when running a JAX script.
I tried debugging with gdb:
gdb python
(gdb) break InitializePjitFunction # add breakpoint at one of the cpp functions that gets called
(gdb) run simple_jax_script.py
...
Thread 1 "python" hit Breakpoint 1, jax::(anonymous namespace)::InitializePjitFunction (fn_obj=0x7fffd6f3c310, function_name="fft", fun=std::optional<nanobind::callable> = {...}, cache_miss=..., static_argnums=std::vector of length 2, capacity 2 = {...}, static_argnames=std::vector of length 2, capacity 2 = {...}, donate_argnums=std::vector of length 0, capacity 0, pytree_registry=std::shared_ptr<xla::PyTreeRegistry> (use count 1, weak count 1) = {...}, shard_arg_fallback=..., cache=std::shared_ptr<jax::(anonymous namespace)::PjitFunctionCache> (use count 1, weak count 0) = {...}) at external/xla/xla/python/pjit.cc:1044
1044 external/xla/xla/python/pjit.cc: No such file or directory.
(gdb)
So gdb cannot find the XLA source files. In my case they are located in /xla, but gdb looks for them under the relative path external/xla instead. How can I build jaxlib so that the debug info still points to the original C++ source files?
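One possible workaround that does not require rebuilding is to remap the recorded path inside gdb. The sketch below is untested against this build and assumes that the XLA checkout is at /xla, and that the path stored in the debug info is exactly the relative external/xla/... shown in the gdb output above. The file name gdb_xla_paths.gdb is hypothetical. `set substitute-path` and `set breakpoint pending on` are standard gdb commands. If Bazel also recorded a compilation directory, the rule might need a different prefix.

```shell
# Write a small gdb command file (hypothetical name: gdb_xla_paths.gdb).
# Assumption: the XLA sources live in /xla, so external/xla/xla/python/pjit.cc
# should resolve to /xla/xla/python/pjit.cc.
cat > gdb_xla_paths.gdb <<'EOF'
# Rewrite the recorded source prefix to the real checkout location.
set substitute-path external/xla /xla
# Allow the breakpoint to be set before the shared library is loaded.
set breakpoint pending on
break InitializePjitFunction
EOF

# Then start gdb with the command file and type `run` at the prompt:
#   gdb -x gdb_xla_paths.gdb --args python simple_jax_script.py
```

If the substitution rule does not match, `info source` at the breakpoint shows the file name and compilation directory that gdb actually recorded, which can be used to adjust the prefix.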
Description
I am trying to build jaxlib with debug symbols for XLA using the following command:
The build runs fine until the very end, where it fails while linking xla_extension.so:
I don't get this error if I omit --enable_cuda, so it is probably related to CUDA. On the other hand, passing --enable_cuda works as long as I don't build with debug symbols (that is, without --bazel_options=--compilation_mode=dbg).
System info (python version, jaxlib version, accelerator, etc.)
python version: Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] on linux
jax: latest commit https://github.com/google/jax.git
xla: latest commit https://github.com/openxla/xla
accelerator (info from nvidia-smi): NVIDIA-SMI 560.28.03, Driver Version: 560.28.03, CUDA Version: 12.6, NVIDIA GeForce RTX 4080 Laptop GPU
To reproduce the error with a docker container: