google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

imagenet example needs a refreshed requirements.txt #3950

Open sycamoreoak opened 1 month ago

sycamoreoak commented 1 month ago

it seems like the flax examples could use a version bump?

System information

pip show flax jax jaxlib
Name: flax
Version: 0.6.5
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page: https://github.com/google/flax
Author: Flax team
Author-email: flax-dev@google.com
License: 
Location: ~/src/python/flax/examples/imagenet/.venv/lib/python3.10/site-packages
Requires: jax, matplotlib, msgpack, numpy, optax, orbax, PyYAML, rich, tensorstore, typing-extensions
Required-by: clu
---
Name: jax
Version: 0.4.28
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: ~/src/python/flax/examples/imagenet/.venv/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, clu, flax, optax, orbax-checkpoint
---
Name: jaxlib
Version: 0.4.28+cuda12.cudnn89
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: ~/src/python/flax/examples/imagenet/.venv/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, clu, optax, orbax-checkpoint

Problem you have encountered:

imagenet with venv: if I create a venv and install the requirements.txt, importing flax fails with:

  File "~/src/python/flax/examples/imagenet/.venv/lib/python3.10/site-packages/flax/configurations.py", line 74, in <module>
    flax_filter_frames = define_bool_state(
  File "~/src/python/flax/examples/imagenet/.venv/lib/python3.10/site-packages/flax/configurations.py", line 40, in define_bool_state
    return jax_config.define_bool_state('flax_' + name, default, help)
AttributeError: 'Config' object has no attribute 'define_bool_state'

pip list reports flax 0.6.5, which matches the 0.6.5 pin in requirements.txt.
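The traceback suggests the pinned flax is calling a jax config helper that the installed jax no longer provides. A quick way to confirm that kind of mismatch in a venv is sketched below. It assumes the `jax_config` in the traceback is the same object as `jax.config`; the helper names `installed_version` and `has_define_bool_state` are made up for this example.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def has_define_bool_state():
    """Return True/False if jax can be imported, None if it cannot."""
    try:
        import jax
    except Exception:
        # jax is missing or fails to import in this environment.
        return None
    # Assumption: flax's `jax_config` refers to `jax.config`.
    return hasattr(jax.config, "define_bool_state")


for name in ("flax", "jax", "jaxlib"):
    print(name, installed_version(name))
print("jax.config.define_bool_state present:", has_define_bool_state())
```

If this prints `False` next to an old flax version, the two packages are probably out of step, and upgrading flax (or pinning an older jax) would be the thing to try.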

What you expected to happen:

imagenet trains

Logs, error messages, etc:

see above

Steps to reproduce:

create a fresh, up-to-date ubuntu install with an nvidia card, clone flax, and install the imagenet example's requirements.txt

sycamoreoak commented 3 weeks ago

OK this appears to work:

python -m venv imagenet_env
source ./imagenet_env/bin/activate
unset LD_LIBRARY_PATH
pip install "jax[cuda12]" flax tensorflow clu tensorflow-datasets
python main.py --workdir=./imagenet --config=configs/default.py

looks like requirements.txt may need to be freshened up. also note that the instructions in the TPU section of the README, which say to run pip install -e . in the flax folder as a way of installing flax, seem to break things.
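To see which pins in a requirements.txt have drifted from what is actually installed, something like the following could help. This is only a rough sketch: it handles plain `name==version` lines and ignores everything else (ranges, extras, URLs), and the function names are hypothetical. The sample values are the flax pin and the flax version from this thread.

```python
from importlib import metadata


def parse_pins(text):
    """Collect {name: version} from `name==version` lines; skip other lines."""
    pins = {}
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()  # drop comments
        if "==" not in line:
            continue
        name, version = line.split("==", 1)
        # Drop any environment marker after a semicolon.
        pins[name.strip()] = version.split(";")[0].strip()
    return pins


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pins, lookup=installed_version):
    """Return {name: (pinned, installed)} for every pin that does not match."""
    mismatches = {}
    for name, pinned in pins.items():
        actual = lookup(name)
        if actual != pinned:
            mismatches[name] = (pinned, actual)
    return mismatches


# Demo with a stand-in lookup instead of the real environment.
sample_requirements = "flax==0.6.5\n"
fake_installed = {"flax": "0.8.4"}
print(find_mismatches(parse_pins(sample_requirements), fake_installed.get))
```

Against a real venv, pass the file contents (for example `pathlib.Path("requirements.txt").read_text()`) to `parse_pins` and call `find_mismatches` with the default lookup.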

here is my pip freeze:

absl-py==2.1.0
array_record==0.5.1
astunparse==1.6.3
cachetools==5.3.3
certifi==2024.6.2
charset-normalizer==3.3.2
chex==0.1.86
click==8.1.7
clu==0.0.12
contextlib2==21.6.0
dm-tree==0.1.8
docstring_parser==0.16
etils==1.7.0
flatbuffers==24.3.25
flax==0.8.4
fsspec==2024.6.0
gast==0.4.0
google-auth==2.30.0
google-auth-oauthlib==1.0.0
google-pasta==0.2.0
grpcio==1.64.1
h5py==3.11.0
idna==3.7
immutabledict==4.2.0
importlib_resources==6.4.0
jax==0.4.29
jax-cuda12-pjrt==0.4.29
jax-cuda12-plugin==0.4.29
jaxlib==0.4.29
keras==2.13.1
libclang==18.1.1
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
ml-collections==0.1.1
ml-dtypes==0.4.0
msgpack==1.0.8
nest-asyncio==1.6.0
numpy==1.24.3
nvidia-cublas-cu12==12.5.2.13
nvidia-cuda-cupti-cu12==12.5.39
nvidia-cuda-nvcc-cu12==12.5.40
nvidia-cuda-runtime-cu12==12.5.39
nvidia-cudnn-cu12==9.1.1.17
nvidia-cufft-cu12==11.2.3.18
nvidia-cusolver-cu12==11.6.2.40
nvidia-cusparse-cu12==12.4.1.24
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.5.40
oauthlib==3.2.2
opt-einsum==3.3.0
optax==0.2.2
orbax-checkpoint==0.5.16
packaging==24.1
promise==2.3
protobuf==3.20.3
psutil==5.9.8
pyarrow==16.1.0
pyasn1==0.6.0
pyasn1_modules==0.4.0
Pygments==2.18.0
PyYAML==6.0.1
requests==2.32.3
requests-oauthlib==2.0.0
rich==13.7.1
rsa==4.9
scipy==1.13.1
simple_parsing==0.1.5
six==1.16.0
tensorboard==2.13.0
tensorboard-data-server==0.7.2
tensorflow==2.13.1
tensorflow-datasets==4.9.6
tensorflow-estimator==2.13.0
tensorflow-io-gcs-filesystem==0.37.0
tensorflow-metadata==1.15.0
tensorstore==0.1.61
termcolor==2.4.0
toml==0.10.2
toolz==0.12.1
tqdm==4.66.4
typing_extensions==4.5.0
urllib3==2.2.1
Werkzeug==3.0.3
wrapt==1.16.0
zipp==3.19.2