jonbarron / camp_zipnerf

Apache License 2.0
Issues with JAX and Orbax Dependencies in CUDA 11.8 Environment #32

Closed. BobH233 closed this issue 1 month ago

BobH233 commented 1 month ago

Hello,

I'm running into a dependency conflict between JAX and Orbax. My environment uses CUDA 11.8. After running pip install -e ., I installed the GPU build of JAX with pip install -U "jax[cuda11]". Running ./scripts/run_all_unit_tests.sh completed without errors. However, during training, at the checkpoint-saving stage, I hit the following error:

File "/path_to_anaconda/envs/my_env/lib/python3.11/site-packages/orbax/checkpoint/multihost/utils.py", line 90, in broadcast_one_to_some
    in_tree = jax.tree.map(pre_jit, in_tree)
              ^^^^^^^^
File "/path_to_anaconda/envs/my_env/lib/python3.11/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'tree'
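
The traceback above comes down to whether the installed jax exposes a top-level `tree` attribute. A minimal sketch of how one could check this before starting a long training run follows; the helper name `has_attribute` is hypothetical and not part of JAX, Orbax, or this repository.

```python
import importlib


def has_attribute(module_name: str, attr: str) -> bool:
    """Return True if the module imports and exposes the given attribute.

    Hypothetical helper for illustration only.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # The module is not installed (or fails to import) in this environment.
        return False
    return hasattr(module, attr)


if __name__ == "__main__":
    # Orbax calls jax.tree.map in the traceback above, so this should be True
    # in an environment where checkpoint saving is expected to work.
    print("jax.tree available:", has_attribute("jax", "tree"))
```

If this prints False, the installed jax is older than what the installed orbax-checkpoint expects, which matches the AttributeError shown above.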

After investigating, I found that the code in orbax-checkpoint requires JAX version 0.4.6 or higher. However, when I manually installed jax and jaxlib version 0.4.6, pip reported the following conflict:

camp-zipnerf 0.0.2 requires jax==0.4.23, but you have jax 0.4.6 which is incompatible.

As a result, the training code could not run correctly:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/path_to_project/my_project/train.py", line 26, in <module>
    import flax
  File "/path_to_anaconda/envs/my_env/lib/python3.11/site-packages/flax/__init__.py", line 23, in <module>
    from . import core
  File "/path_to_anaconda/envs/my_env/lib/python3.11/site-packages/flax/core/__init__.py", line 15, in <module>
    from .axes_scan import broadcast as broadcast
  File "/path_to_anaconda/envs/my_env/lib/python3.11/site-packages/flax/core/axes_scan.py", line 22, in <module>
    from jax.extend import linear_util as lu
ModuleNotFoundError: No module named 'jax.extend'

I would appreciate more detailed guidance on how to properly configure my environment. I have already tried using CUDA versions 12.3, 11.8, and even 11.6, but none of them worked successfully.

If possible, could you please provide some assistance or point me toward the correct setup steps?

Thank you in advance for your help!

BobH233 commented 1 month ago

I discovered that the dependency chain is as follows: flax==0.7.5 depends on orbax-checkpoint==Any, and all versions of orbax-checkpoint seem to depend on jax>=0.4.26, which is causing the conflict. I'm not sure how to resolve this issue.

BobH233 commented 1 month ago

Solved by using CUDA 12.3 and upgrading as follows:

pip install "jax[cuda12_pip]==0.4.26" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install flax==0.8.4
pip install jax==0.4.26