BobH233 closed this issue 1 month ago.
I discovered that the dependency chain is as follows: flax==0.7.5 depends on orbax-checkpoint (any version), and all versions of orbax-checkpoint seem to depend on jax>=0.4.26, which is causing the conflict. I'm not sure how to resolve this issue.
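To see which side of this conflict an environment is on, a small check can print the installed versions and the jax requirements that orbax-checkpoint declares in its package metadata. This is a minimal sketch using only the standard library's `importlib.metadata`. The helper names are hypothetical, and the 0.4.26 floor is taken from the observation above; treat it as an assumption to verify against the orbax-checkpoint release you actually install.

```python
from __future__ import annotations

import re
from importlib import metadata

# Assumed minimum JAX version, taken from the dependency chain reported in this thread.
MIN_JAX = "0.4.26"

PACKAGES = ("jax", "jaxlib", "flax", "orbax-checkpoint")


def release_tuple(version: str) -> tuple[int, ...]:
    """Return the leading numeric release segment, e.g. '0.4.26.dev1' -> (0, 4, 26)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def installed_versions(names=PACKAGES) -> dict[str, str | None]:
    """Map each distribution name to its installed version, or None if it is absent."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def jax_is_new_enough(jax_version: str | None, minimum: str = MIN_JAX) -> bool:
    """True if an installed JAX version meets the assumed minimum release."""
    if jax_version is None:
        return False
    return release_tuple(jax_version) >= release_tuple(minimum)


if __name__ == "__main__":
    versions = installed_versions()
    for name, ver in versions.items():
        print(f"{name:>18}: {ver or 'not installed'}")

    # Show the jax/jaxlib requirements that orbax-checkpoint itself declares.
    try:
        reqs = metadata.requires("orbax-checkpoint") or []
        print("orbax-checkpoint declares:", [r for r in reqs if r.startswith("jax")])
    except metadata.PackageNotFoundError:
        print("orbax-checkpoint is not installed")

    if not jax_is_new_enough(versions["jax"]):
        print(f"jax is missing or older than {MIN_JAX}; orbax-checkpoint may not work.")
```

The version comparison deliberately keeps only the numeric release part (so `0.4.26rc1` compares as `0.4.26`) to avoid depending on the third-party `packaging` library.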
Solved by using CUDA 12.3 and running the following upgrades:
pip install "jax[cuda12_pip]"==0.4.26 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install flax==0.8.4
pip install jax==0.4.26
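After these installs, a quick smoke test can confirm that JAX is using the GPU and that `orbax.checkpoint` imports cleanly. This is a hedged sketch: `describe_jax_setup` is a hypothetical helper, and it only checks that the module imports, which is not a full test of checkpoint saving.

```python
from importlib import metadata


def describe_jax_setup() -> dict:
    """Hypothetical helper: report versions, the active JAX backend, and whether
    orbax.checkpoint imports. Errors are recorded in the result instead of raised."""
    info = {}
    for name in ("jax", "jaxlib", "flax", "orbax-checkpoint"):
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = None

    try:
        import jax

        # Typically "gpu" on a working CUDA install; "cpu" means JAX did not find the GPU.
        info["backend"] = jax.default_backend()
        info["devices"] = [str(d) for d in jax.devices()]
    except Exception as exc:  # ImportError, or a backend initialisation failure
        info["backend"] = None
        info["devices"] = []
        info["jax_error"] = repr(exc)

    try:
        # Only checks that the module imports; it does not exercise a real save.
        import orbax.checkpoint  # noqa: F401

        info["orbax_import_ok"] = True
    except Exception as exc:
        info["orbax_import_ok"] = False
        info["orbax_error"] = repr(exc)

    return info


if __name__ == "__main__":
    for key, value in describe_jax_setup().items():
        print(f"{key}: {value}")
```

Running it as a script prints one line per key; a `cpu` backend after installing the CUDA extras would suggest the CUDA-enabled jaxlib was not picked up.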
Hello,
I'm encountering issues with the dependencies between JAX and Orbax in my current setup. My environment uses CUDA 11.8. After running pip install -e ., I configured the GPU version of JAX with the command pip install -U "jax[cuda11]". Running ./scripts/run_all_unit_tests.sh completed without errors. However, during the training and saving stages, I encountered the following error:
After investigating, I found that the code in orbax-checkpoint requires JAX version 0.4.6 or higher. However, when I manually upgraded JAX and JAXlib to version 0.4.6, I received the following error:
As a result, the training code could not run correctly:
I would appreciate more detailed guidance on how to properly configure my environment. I have already tried CUDA versions 12.3, 11.8, and even 11.6, but none of them worked.
If possible, could you please provide some assistance or point me toward the correct setup steps?
Thank you in advance for your help!