google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0
6.11k stars, 645 forks

AttributeError: module 'jax.config' has no attribute 'define_bool_state' when running `big_vision` on `tpu-vm-base` #3180

Open: yuyangshu opened this issue 1 year ago

yuyangshu commented 1 year ago

System information

Problem you have encountered:

Hello, I am seeing

AttributeError: module 'jax.config' has no attribute 'define_bool_state'

when running the big_vision library on tpu-vm-base.

I saw a similar issue in another library that was apparently resolved by pinning jax to 0.4.9, but pinning it did not work for me.

I also tried pinning every package in big_vision's requirements.txt, i.e.

absl-py==1.4.0
clu==0.0.8
einops==0.6.0
flax==0.6.7
git+https://github.com/google/flaxformer
git+https://github.com/deepmind/optax.git
git+https://github.com/akolesnikoff/panopticapi.git@mute
overrides==7.3.1
tensorflow==2.12.0
tfds-nightly==4.8.3.dev202303250044
tensorflow-addons==0.19.0
tensorflow-text==2.12.0
tensorflow-gan==2.1.0

while also pinning jax to 0.4.9, but that did not work either.

To get the library running on TPU again, I had to use a full requirements.txt produced by running pip freeze in a local venv created on 2023-03-26.

What you expected to happen:

As recently as 2023-05-24, I was able to run big_vision on tpu-vm-base on a v3-8 TPU node without pinning any package versions.

Logs, error messages, etc:

Installing collected packages: libtpu-nightly, zipp, numpy, scipy, opt-einsum, ml-dtypes, importlib-metadata, jaxlib, jax
Successfully installed importlib-metadata-6.7.0 jax-0.4.13 jaxlib-0.4.13 libtpu-nightly-0.1.dev20230622 ml-dtypes-0.2.0 numpy-1.24.4 opt-einsum-3.3.0 scipy-1.10.1 zipp-3.15.0
Collecting git+https://github.com/google/flaxformer (from -r big_vision/requirements.txt (line 5))
  Cloning https://github.com/google/flaxformer to /tmp/pip-req-build-925ai1ze
  Running command git clone --filter=blob:none --quiet https://github.com/google/flaxformer /tmp/pip-req-build-925ai1ze
  Resolved https://github.com/google/flaxformer to commit 9adaa4467cf17703949b9f537c3566b99de1b416
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'

[omitted]

Collecting flax==0.6.7 (from -r big_vision/requirements.txt (line 4))
  Downloading flax-0.6.7-py3-none-any.whl (214 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 214.2/214.2 kB 28.6 MB/s eta 0:00:00

[omitted]

Building wheels for collected packages: flaxformer, optax, panopticapi, ml-collections, promise
  Building wheel for flaxformer (pyproject.toml): started
  Building wheel for flaxformer (pyproject.toml): finished with status 'done'
  Created wheel for flaxformer: filename=flaxformer-0.8.1-py3-none-any.whl size=321948 sha256=df38d4209289e8a71a245b56f95490ec0ce9c2bbfaa164fd00d1b7e2f80b5869

[omitted]

Successfully installed MarkupSafe-2.1.3 Pillow-10.0.0 PyYAML-6.0 absl-py-1.4.0 aqtp-0.1.1 array-record-0.4.0 astunparse-1.6.3 cached_property-1.5.2 cachetools-5.3.1 certifi-2023.5.7 charset-normalizer-3.1.0 chex-0.1.7 click-8.1.3 cloudpickle-2.2.1 clu-0.0.8 contextlib2-21.6.0 dacite-1.8.1 decorator-5.1.1 dm-tree-0.1.8 einops-0.6.0 etils-1.3.0 flatbuffers-23.5.26 flax-0.6.7 flaxformer-0.8.1 gast-0.4.0 google-auth-2.21.0 google-auth-oauthlib-1.0.0 google-pasta-0.2.0 googleapis-common-protos-1.59.1 grpcio-1.56.0 h5py-3.9.0 idna-3.4 importlib-resources-5.12.0 keras-2.12.0 libclang-16.0.0 markdown-3.4.3 markdown-it-py-3.0.0 mdurl-0.1.2 ml-collections-0.1.1 msgpack-1.0.5 nest_asyncio-1.5.6 numpy-1.23.5 oauthlib-3.2.2 optax-0.1.5 orbax-0.1.7 overrides-7.3.1 packaging-23.1 panopticapi-0.1 promise-2.3 protobuf-4.23.3 psutil-5.9.5 pyasn1-0.5.0 pyasn1-modules-0.3.0 pygments-2.15.1 requests-2.31.0 requests-oauthlib-1.3.1 rich-13.4.2 rsa-4.9 six-1.16.0 tensorboard-2.12.3 tensorboard-data-server-0.7.1 tensorflow-2.12.0 tensorflow-addons-0.19.0 tensorflow-datasets-4.9.2 tensorflow-estimator-2.12.0 tensorflow-gan-2.1.0 tensorflow-hub-0.13.0 tensorflow-io-gcs-filesystem-0.32.0 tensorflow-metadata-1.13.1 tensorflow-probability-0.20.1 tensorflow-text-2.12.0 tensorstore-0.1.40 termcolor-2.3.0 tfds-nightly-4.8.3.dev202303250044 toml-0.10.2 toolz-0.12.0 tqdm-4.65.0 typeguard-4.0.0 typing-extensions-4.7.1 urllib3-1.26.16 werkzeug-2.3.6 wheel-0.40.0 wrapt-1.14.1
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/yuyang/big_vision/train.py", line 28, in <module>
    import big_vision.evaluators.common as eval_common
  File "/home/yuyang/big_vision/evaluators/common.py", line 22, in <module>
    import flax
  File "/home/yuyang/bv_venv/lib/python3.8/site-packages/flax/__init__.py", line 18, in <module>
    from .configurations import (
  File "/home/yuyang/bv_venv/lib/python3.8/site-packages/flax/configurations.py", line 93, in <module>
    flax_filter_frames = define_bool_state(
  File "/home/yuyang/bv_venv/lib/python3.8/site-packages/flax/configurations.py", line 42, in define_bool_state
    return jax_config.define_bool_state('flax_' + name, default, help)
AttributeError: module 'jax.config' has no attribute 'define_bool_state'
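
A quick way to confirm which versions actually ended up installed, and whether the installed jax still exposes the function that flax/configurations.py calls in the traceback above, is a small diagnostic script. The sketch below is not part of flax or big_vision; the helper names are hypothetical, and it relies only on the standard library plus an optional `import jax`.

```python
import importlib
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "flax")):
    """Map each distribution name to its installed version string, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def jax_config_has_define_bool_state():
    """True/False if jax imports, None if jax is not installed.

    Probes the same attribute that flax/configurations.py calls
    (jax_config.define_bool_state) before flax itself is imported.
    """
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return None
    return hasattr(jax.config, "define_bool_state")


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
    print("jax.config.define_bool_state present:", jax_config_has_define_bool_state())
```

Running this inside the venv before launching training shows at a glance whether the mismatch is in jax, jaxlib or flax.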

Steps to reproduce:

  1. Check out big_vision locally:

    git clone git@github.com:google-research/big_vision.git

  2. Create a TPU node:

    gcloud compute tpus tpu-vm create $VM_NAME --zone=$ZONE --accelerator-type=v3-8 --version=tpu-vm-base

  3. Upload big_vision to the TPU and start training:

    gcloud compute tpus tpu-vm scp --recurse big_vision/big_vision $VM_NAME: --zone=$ZONE --worker=all
    gcloud compute tpus tpu-vm ssh $VM_NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/vit_s16_i1k.py --workdir gs://$BUCKET_NAME/workdirs/`date '+%m-%d_%H%M'`"

cgarciae commented 1 year ago

Hey @yuyangshu, currently jax does expose define_bool_state:

https://github.com/google/jax/blob/main/jax/_src/config.py#L174

Can you try using flax==0.6.11?

hwh7 commented 1 year ago

I have the same issue when running with CUDA. Switching to flax==0.6.11 doesn't fix it for me.

chiamp commented 1 year ago

Could you try upgrading flax to the latest version?

jiyuuchc commented 8 months ago

I can confirm this issue with jax==0.4.25.

Downgrading to jax==0.4.24 solves the problem.

Alxmrphi commented 8 months ago

I have the same issue now (flax==0.6.11 and jax==0.4.25). Downgrading jax to 0.4.24 gives me a different error, whose fix is to downgrade further to 0.4.23. Let's see where this ends...

chiamp commented 8 months ago

@jiyuuchc what flax version are you using?

jiyuuchc commented 8 months ago

flax == 0.7.5

In fact, on jax==0.4.24, referencing jax.config.define_bool_state already raises a DeprecationWarning:

DeprecationWarning: jax.config.define_bool_state is deprecated. Please use other libraries for configuration instead. <function define_bool_state at 0x7f0fadff95a0>

So the removal seems to have been planned all along; I don't know why it came as a surprise.
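
A DeprecationWarning like the one above is easy to miss in a long training log. One way to catch such warnings before a release removes the API is to escalate them to errors, either for the whole process with `python -W error::DeprecationWarning` (or `PYTHONWARNINGS=error::DeprecationWarning`), or locally with the standard `warnings` module as sketched below. `legacy_define_bool_state` is a hypothetical stand-in that only mimics a deprecated call; it is not the real JAX function.

```python
import warnings


def legacy_define_bool_state(name, default, help_text):
    """Hypothetical stand-in for a deprecated config helper; warns, then returns default."""
    warnings.warn(
        f"define_bool_state({name!r}) is deprecated", DeprecationWarning, stacklevel=2
    )
    return default


def collect_deprecations(func, *args, **kwargs):
    """Call func and return the messages of every DeprecationWarning it emitted."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always", DeprecationWarning)
        func(*args, **kwargs)
    return [str(w.message) for w in caught if issubclass(w.category, DeprecationWarning)]


def call_strict(func, *args, **kwargs):
    """Call func with DeprecationWarning escalated to an exception."""
    with warnings.catch_warnings():
        warnings.simplefilter("error", DeprecationWarning)
        return func(*args, **kwargs)


if __name__ == "__main__":
    print(collect_deprecations(legacy_define_bool_state, "filter_frames", True, "doc"))
    try:
        call_strict(legacy_define_bool_state, "filter_frames", True, "doc")
    except DeprecationWarning as err:
        print("failed fast:", err)
```

Running a test suite in the strict mode turns "deprecated in 0.4.24" into a failure on that release, rather than an AttributeError on the next one.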


chiamp commented 8 months ago

@Alxmrphi can you try upgrading to flax==0.7.5 or higher?

sudo-Boris commented 6 months ago

I need CUDA support. The only command that installs a JAX version that finds my GPUs is the standard pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html.

But this command installs jax 0.4.26, which has the error discussed here. When I manually downgrade to, say, 0.4.22 (pip install jax==0.4.22), I get a version that doesn't find my GPU...

Any idea on what to do?

Alxmrphi commented 6 months ago

@sudo-Boris Could you try specifying the version like so:

pip install "jax[cuda12_pip]"==0.4.22 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

This has previously worked for me.
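
After installing a CUDA build like the one above, it is worth confirming that JAX actually sees the GPU before chasing other errors. A minimal sketch, assuming only the public `jax.__version__`, `jax.default_backend()` and `jax.devices()` APIs; the function name is hypothetical.

```python
def jax_backend_summary():
    """Report the jax version, default backend and visible devices, or None without jax."""
    try:
        import jax
    except ImportError:
        return None
    return {
        "jax": jax.__version__,
        "backend": jax.default_backend(),  # platform name such as 'cpu', 'gpu' or 'tpu'
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    summary = jax_backend_summary()
    if summary is None:
        print("jax is not installed")
    else:
        print(summary)
        if summary["backend"] != "gpu":
            print("jax did not pick up a GPU; check the jaxlib/CUDA wheel")
```

If the backend comes back as "cpu", the installed jaxlib is most likely a CPU-only wheel, which matches the "version that doesn't find my GPU" symptom described above.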

sudo-Boris commented 6 months ago

@Alxmrphi Thank you for the quick reply!

I just can't seem to find a CUDA-enabled jax version that works with my CUDA version (12.1).

For 0.4.23, I get WARNING: jax 0.4.23 does not provide the extra 'cuda12-pip' at installation, and when executing code I get jaxlib.xla_extension.XlaRuntimeError: INTERNAL: XLA requires ptxas version 11.8 or higher.

For 0.4.24, I get the same warning at installation, plus the following annoying warning when executing the same code:

W external/xla/xla/service/gpu/nvptx_compiler.cc:744] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.

I guess I need to tweak around with the different library versions to find a combination that works.

Thank you :)

Alxmrphi commented 6 months ago

Oh, I've been there: your situation is exactly what I was dealing with not long ago (as you can see from my earlier comments in this thread). I can tell you what I'm currently using, which works with CUDA 12.2, so hopefully it also works with CUDA 12.1, since I'm on an earlier JAX version than you (patch version 20 instead of 22). This is from an HPC job, so I'll just copy the header of my environment setup script. Take the version numbers from the settings below and try those out.

module purge
module load gcc/11.3.0
module load python/3.10
module load cuda/12.2
module load cudnn
module load scipy-stack

virtualenv --no-download $ENVDIR
source $ENVDIR/bin/activate

pip install --no-index torch torchvision
pip install --no-index flax==0.7.5+computecanada
pip install --no-index wandb
pip install jax==0.4.20+computecanada --no-index
pip install orbax_checkpoint==0.5.2+computecanada --no-index

If it doesn't work for you, I guess you'll just have to tweak the versions. Outside an HPC module setup like this one, I think you'll still need the jax[cuda12_pip]==version style of installation.

Best of luck :)

eleninisioti commented 4 months ago

I had the same issue, and in the end these versions worked for me:

jaxlib: 0.4.26+cuda12.cudnn89 (installed from https://storage.googleapis.com/jax-releases/jax_cuda_releases.html )
jax: 0.4.26
flax: 0.8.4
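
To check an environment against the combination reported as working here, a script can compare installed versions while ignoring local build tags such as +cuda12.cudnn89 (or +computecanada from the HPC setup earlier in the thread). This is a sketch with hypothetical helper names; the pinned numbers are simply the ones reported above, not an officially supported matrix.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Combination reported as working in this thread (jaxlib was the
# CUDA 12 / cuDNN 8.9 wheel from the jax_cuda_releases index).
REPORTED_WORKING = {"jax": (0, 4, 26), "jaxlib": (0, 4, 26), "flax": (0, 8, 4)}


def parse_release(version_string):
    """Leading numeric release tuple: '0.4.26+cuda12.cudnn89' -> (0, 4, 26)."""
    public = version_string.split("+", 1)[0]  # drop local tags like +computecanada
    parts = []
    for piece in public.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at markers such as 'dev202303250044'
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # e.g. '0rc1' keeps 0 and ignores the rest
    return tuple(parts)


def installed_version(name):
    """Installed version string of a distribution, or None if it is absent."""
    try:
        return version(name)
    except PackageNotFoundError:
        return None


def mismatches(expected=REPORTED_WORKING, lookup=installed_version):
    """Return {package: installed version or None} for packages that differ from `expected`."""
    differing = {}
    for name, wanted in expected.items():
        have = lookup(name)
        if have is None or parse_release(have) != wanted:
            differing[name] = have
    return differing


if __name__ == "__main__":
    print(mismatches() or "matches the reported combination")
```

Passing a dict's .get as lookup lets you exercise the comparison without touching the real environment.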

shivanraptor commented 2 months ago

These versions fixed the problem in the title, but raised:

RuntimeError: Failed to import transformers.models.bart.modeling_flax_bart because of the following error (look up to see its traceback): module 'jax.numpy' has no attribute 'DeviceArray'
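
This new error comes from code that still refers to jax.numpy.DeviceArray, a name that newer JAX releases no longer provide; current JAX exposes jax.Array as its array type instead. Upgrading transformers to a release that no longer uses the old name is probably the cleaner fix, but for code you control, feature detection keeps isinstance checks working on both old and new JAX. The function name below is hypothetical, and the modules are passed in so the logic can be exercised without JAX installed.

```python
def resolve_array_type(jax_module, jnp_module):
    """Return the array class to use for isinstance checks.

    Prefers jax.Array (current JAX) and only falls back to the legacy
    jax.numpy.DeviceArray name when jax.Array does not exist.
    """
    array_type = getattr(jax_module, "Array", None)
    if array_type is None:
        array_type = getattr(jnp_module, "DeviceArray")  # old JAX only
    return array_type


# Usage with real JAX (not executed here):
#   import jax
#   import jax.numpy as jnp
#   ArrayType = resolve_array_type(jax, jnp)
#   isinstance(jnp.ones(3), ArrayType)  # True
```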

malcobak commented 1 month ago

@eleninisioti Thanks, that combination worked for me!

jaxlib: 0.4.26+cuda12.cudnn89 (installed from https://storage.googleapis.com/jax-releases/jax_cuda_releases.html )
jax: 0.4.26
flax: 0.8.4