yuyangshu opened this issue 1 year ago
Hey @yuyangshu, currently jax does expose define_bool_state:
https://github.com/google/jax/blob/main/jax/_src/config.py#L174
Can you try using flax==0.6.11?
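For reference, this is roughly the call that newer jax releases no longer expose (a minimal sketch, not big_vision's actual code; the flag name is made up for illustration):

import jax

# Worked (with a DeprecationWarning) up to jax==0.4.24; on jax>=0.4.25 it raises:
# AttributeError: module 'jax.config' has no attribute 'define_bool_state'
jax.config.define_bool_state(
    name="example_flag",
    default=False,
    help="An example boolean config flag.",
)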
I have the same issue when I try to run with CUDA. Trying flax==0.6.11 doesn't work for me.
Could you try upgrading flax to the latest version?
Can confirm this is an issue with jax==0.4.25. Downgrading to jax==0.4.24 solves the problem.
I have the same issue now (flax==0.6.11 and jax==0.4.25). Downgrading to 0.4.24 now gives me a different error, to which the solution is to downgrade to 0.4.23. Let's see at what version this ends...
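If you end up in this kind of downgrade dance, a quick sanity check that the versions you pinned are actually the ones being imported:

import jax
import flax

# Verify the interpreter picks up the pinned versions, not a stale install.
print("jax:", jax.__version__, "| flax:", flax.__version__)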
@jiyuuchc what flax version are you using?
flax == 0.7.5
In fact, on jax==0.4.24, referencing jax.config.define_bool_state already raises a DeprecationWarning:
DeprecationWarning: jax.config.define_bool_state is deprecated. Please use other libraries for configuration instead. <function define_bool_state at 0x7f0fadff95a0>
So it seems this was all planned; I don't know why it came as a surprise to begin with.
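Incidentally, if you want deprecations like this to fail loudly before the attribute is removed in a later release, Python's standard warnings filter can promote them to errors (a generic sketch, not something the packages in this thread require):

import warnings

# Turn DeprecationWarning into an exception so deprecated jax APIs
# surface during testing instead of breaking after the next upgrade.
warnings.filterwarnings("error", category=DeprecationWarning)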
@Alxmrphi can you try upgrading to flax==0.7.5 or higher?
I need to specify that I want to use CUDA. The only command that installs a JAX version that finds my GPUs is the standard

pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

But this command installs jax 0.4.26, which has the discussed error. When I manually downgrade to, let's say, 0.4.22 (pip install jax==0.4.22), I again get a version that doesn't find my GPU...
Any idea on what to do?
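One way to check whether a given install actually sees the GPU, using standard jax calls:

import jax

# Lists the devices jax can use; a working CUDA install shows GPU
# devices here instead of only CPU.
print(jax.devices())
print(jax.default_backend())  # "gpu" on a working CUDA setup, "cpu" otherwise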
@sudo-Boris Could you try specifying the version like so:
pip install "jax[cuda12_pip]"==0.4.22 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
This has previously worked for me.
@Alxmrphi Thank you for the quick reply!
I just can't seem to find a CUDA-enabled jax version that works with my CUDA version (12.1).

For 0.4.23, I get WARNING: jax 0.4.23 does not provide the extra 'cuda12-pip' at installation, and when executing code I get jaxlib.xla_extension.XlaRuntimeError: INTERNAL: XLA requires ptxas version 11.8 or higher.

For 0.4.24, I get the same warning at installation and the following annoying warning when executing the same code:

W external/xla/xla/service/gpu/nvptx_compiler.cc:744] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
I guess I need to tweak around with the different library versions to find a combination that works.
Thank you :)
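When hunting for a working combination, it may also help to dump everything jax knows about the environment in one go (this helper exists in recent jax releases):

import jax

# Prints jax/jaxlib/Python/NumPy versions, the device list, and,
# where available, nvidia-smi output showing the driver's CUDA version.
jax.print_environment_info()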
Oh, I've been there; your situation is what I was dealing with not long ago (as you can see from my earlier comments in this thread). I can share what I'm currently using that works with CUDA 12.2, which will hopefully also work with CUDA 12.1, since I'm on an earlier JAX version than you (patch version 20 instead of 22). This is from an HPC job, so I'll just copy my environment setup script header. Take the version numbers from the settings below and try those.
module purge
module load gcc/11.3.0
module load python/3.10
module load cuda/12.2
module load cudnn
module load scipy-stack
virtualenv --no-download $ENVDIR
source $ENVDIR/bin/activate
pip install --no-index torch torchvision
pip install --no-index flax==0.7.5+computecanada
pip install --no-index wandb
pip install jax==0.4.20+computecanada --no-index
pip install orbax_checkpoint==0.5.2+computecanada --no-index
If it doesn't work for you then I guess you just have to tweak the versions. I think you'll still need to use the jax[cuda12_pip]==version style of installation in a regular terminal setup.
Best of luck :)
I had the same issue and in the end these versions worked for me:

jaxlib: 0.4.26+cuda12.cudnn89 (installed from https://storage.googleapis.com/jax-releases/jax_cuda_releases.html)
jax: 0.4.26
flax: 0.8.4
These versions fixed the problem in the title, but raised:
RuntimeError: Failed to import transformers.models.bart.modeling_flax_bart because of the following error (look up to see its traceback): module 'jax.numpy' has no attribute 'DeviceArray'
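That second error is another symptom of the same kind of API removal: jax.numpy.DeviceArray was deprecated and eventually removed in favor of jax.Array, and older transformers releases still reference the removed name, so upgrading transformers should resolve it. A minimal sketch of the rename:

import jax
import jax.numpy as jnp

x = jnp.ones((2, 2))
# Old (removed): isinstance(x, jnp.DeviceArray)
# New:
print(isinstance(x, jax.Array))  # True on recent jax versions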
@eleninisioti Thanks, that combination worked for me!

jaxlib: 0.4.26+cuda12.cudnn89 (installed from https://storage.googleapis.com/jax-releases/jax_cuda_releases.html)
jax: 0.4.26
flax: 0.8.4
System information
pip show flax jax jaxlib: flax: 0.6.7, jax: 0.4.13, jaxlib: 0.4.13

Problem you have encountered:
Hello, I am seeing AttributeError: module 'jax.config' has no attribute 'define_bool_state' when running the big_vision library on tpu-vm-base. I saw a similar issue in another library, where it seemed to be resolved by pinning jax to 0.4.9, but when I attempted that it did not work. I also tried pinning the versions of all packages in the requirements.txt of big_vision at the same time as pinning jax to 0.4.9, but that did not work either. I had to use a full requirements.txt, obtained from running pip freeze in a local venv created on 2023-03-26, to get the library running on TPU again.

What you expected to happen:
I was able to run big_vision on tpu-vm-base on a v3-8 TPU node without pinning any package versions as late as 2023-05-24.

Logs, error messages, etc:
Steps to reproduce:
Set up big_vision locally
Create a TPU node
Upload big_vision to the TPU and start training