pangeo-data / pangeo-docker-images

Docker Images For Pangeo Jupyter Environment
https://pangeo-docker-images.readthedocs.io
MIT License

Flax needs to be upgraded in the tensorflow/jax image #489

Closed: dhruvbalwada closed this issue 3 months ago

dhruvbalwada commented 11 months ago

Describe the bug: The current version of flax (0.6.1) on the image does not work properly with the installed jax version (0.4.13).

To Reproduce: The issue can be reproduced by running from flax.training import checkpoints, which raises ModuleNotFoundError: No module named 'jax.experimental.global_device_array'. This has been discussed in https://github.com/google/flax/issues/3087.

Expected behavior: Flax should be importable.

Infrastructure (Where you are running this image):

Solution: For now, I work around this by running pip install flax==0.6.10
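To check whether the failing import works in a particular image (for example, before and after the pip workaround), a small diagnostic like the one below can help. It is a minimal sketch: probe_import is a hypothetical helper, not part of flax or jax, and it simply reports whatever exception the import raises.

```python
import importlib
from typing import Optional


def probe_import(module_name: str) -> Optional[str]:
    """Import module_name; return None on success, else a one-line error summary."""
    try:
        importlib.import_module(module_name)
    except Exception as exc:  # catch broadly: import-time failures are not always ImportError
        return f"{type(exc).__name__}: {exc}"
    return None


if __name__ == "__main__":
    # The import that fails in this issue with flax 0.6.1 and jax 0.4.13.
    error = probe_import("flax.training.checkpoints")
    print("import OK" if error is None else error)
```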

weiji14 commented 11 months ago

The latest version of flax on conda-forge is still 0.6.1 (see https://anaconda.org/conda-forge/flax/files). It seems the package went unmaintained for a while, but there is some effort to bump it to v0.7.x, see e.g. https://github.com/conda-forge/flax-feedstock/pull/29

We could either wait for the updates on flax-feedstock, or use pip to install flax in the ml-notebook/environment.yml file as a temporary measure.

weiji14 commented 10 months ago

OK, it looks like flax=0.7.4 and flax=0.7.5 are available on conda-forge now. However, there is a dependency conflict with tensorflow. I tried this environment.yml file:

name: pangeo
channels:
 - conda-forge
 - nodefaults
dependencies:
 - flax>=0.7.0
 - jax
 - jupyterlab-nvdashboard
 - keras-cv
 - tensorflow>=2.13.1=*cuda112*

and got this traceback:

Locking dependencies for ['linux-64']...
INFO:conda_lock.conda_solver:linux-64 using specs ['flax >=0.7.0', 'jax', 'jupyterlab-nvdashboard', 'keras-cv', 'tensorflow >=2.13.1 *cuda112*', 'adlfs', 'argopy', 'awscli', 'boto3', 'bottleneck', 'cartopy', 'cdsapi', 'cfgrib', 'ciso', 'cmocean', 'dask-ml', 'datashader', 'descartes', 'earthaccess', 'eofs', 'erddapy', 'esmpy', 'fastjmd95', 'flox', 'fsspec', 'gcm_filters', 'gcsfs', 'gh', 'gh-scoped-creds', 'geocube', 'geopandas', 'geopy', 'geoviews-core', 'git-lfs', 'gsw', 'h5netcdf', 'h5py', 'holoviews', 'hvplot', 'intake', 'intake-esm', 'intake-geopandas', 'intake-stac', 'intake-xarray', 'ipykernel', 'ipyleaflet', 'ipytree', 'ipywidgets', 'jupyterlab-git', 'jupyter-panel-proxy', 'jupyter-resource-usage', 'kerchunk', 'lxml', 'lz4', 'matplotlib-base', 'metpy', 'nb_conda_kernels', 'nbstripout', 'nc-time-axis', 'netcdf4', 'numbagg', 'numcodecs', 'numpy', 'numpy_groupies', 'odc-stac', 'pandas', 'panel', 'parcels', 'param', 'pop-tools', 'pyarrow', 'pycamhd', 'pydap', 'pystac', 'pystac-client', 'python-blosc', 'python-gist', 'python-graphviz', 'rasterio', 'rechunker', 'rio-cogeo', 'rioxarray', 's3fs', 'satpy', 'scikit-image', 'scikit-learn', 'scipy', 'seaborn', 'sparse', 'snakeviz', 'stackstac', 'tiledb-py', 'timezonefinder', 'xarray', 'xarrayutils', 'xarray-datatree', 'xarray_leaflet', 'xarray-spatial', 'xbatcher', 'xcape', 'xclim', 'xesmf', 'xgboost', 'xgcm', 'xhistogram', 'xmip', 'xmitgcm', 'xpublish', 'xrft', 'xskillscore', 'zarr', 'python 3.11.*', 'pangeo-notebook 2023.11.11.*', 'pip']
Failed to parse json, Expecting value: line 1 column 1 (char 0)
Could not lock the environment for platform linux-64
Could not solve for environment specs
The following packages are incompatible
├─ flax >=0.7.0  is installable and it requires
│  └─ tensorstore   with the potential options
│     ├─ tensorstore [0.1.44|0.1.46|0.1.47|0.1.48] would require
│     │  └─ libprotobuf [>=4.24.3,<4.24.4.0a0 |>=4.24.4,<4.24.5.0a0 ], which requires
│     │     └─ libabseil >=20230802.1,<20230803.0a0 , which can be installed;
│     └─ tensorstore 0.1.44 would require
│        └─ libprotobuf >=4.23.4,<4.23.5.0a0  with the potential options
│           ├─ libprotobuf 4.23.4, which can be installed;
│           └─ libprotobuf 4.23.4 would require
│              └─ libabseil >=20230802.0,<20230803.0a0 , which can be installed;
└─ tensorflow >=2.13.1 *cuda112* is uninstallable because it requires
   └─ tensorflow-base [2.13.1 cuda112py310hbb601f2_1|2.13.1 cuda112py311h8bdbb6c_1|2.13.1 cuda112py38h79651c7_1|2.13.1 cuda112py39h85a252b_1], which requires
      ├─ libabseil >=20230125.3,<20230126.0a0 , which conflicts with any installable versions previously reported;
      └─ libgrpc >=1.54.3,<1.55.0a0 , which requires
         └─ libprotobuf >=3.21.12,<3.22.0a0 , which conflicts with any installable versions previously reported.

It seems that flax -> tensorstore depends on libprotobuf~=4.24, while tensorflow-base -> libgrpc depends on libprotobuf~=3.21. This looks like a longstanding issue (https://github.com/conda-forge/tensorflow-feedstock/issues/288), and the migration at ~either https://github.com/conda-forge/tensorflow-feedstock/pull/347 or~ https://github.com/conda-forge/tensorflow-feedstock/pull/342 might help.
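The conflict can be illustrated with a simplified model that treats each libprotobuf pin from the solver output as a half-open version interval and checks whether any two overlap. This is only a sketch of the reasoning, not how conda's solver works: it reads conda's "<4.24.4.0a0"-style upper bounds as "<4.24.4" and ignores libabseil and build strings.

```python
from typing import List, NamedTuple, Optional, Tuple

Version = Tuple[int, ...]


def parse(v: str) -> Version:
    """Parse a dotted release such as '4.24.3' into a comparable tuple."""
    return tuple(int(part) for part in v.split("."))


class Range(NamedTuple):
    """Half-open interval [low, high), modelling a '>=low,<high.0a0' pin."""
    low: Version
    high: Version


def intersect(a: Range, b: Range) -> Optional[Range]:
    """Return the overlap of two half-open ranges, or None if they are disjoint."""
    low, high = max(a.low, b.low), min(a.high, b.high)
    return Range(low, high) if low < high else None


# libprotobuf ranges reachable through flax -> tensorstore (from the solver output).
tensorstore_options: List[Range] = [
    Range(parse("4.23.4"), parse("4.23.5")),
    Range(parse("4.24.3"), parse("4.24.4")),
    Range(parse("4.24.4"), parse("4.24.5")),
]
# libprotobuf range required through tensorflow-base -> libgrpc.
libgrpc_needs = Range(parse("3.21.12"), parse("3.22"))

# Prints True: no tensorstore option shares a libprotobuf version with libgrpc's pin.
print(all(intersect(opt, libgrpc_needs) is None for opt in tensorstore_options))
```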

Edit: ~Might need to wait for tensorflow 2.15.0 at https://github.com/conda-forge/tensorflow-feedstock/pull/353?~ Nope, conda-forge's tensorflow=2.15.0 doesn't help. They still need to handle the libprotobuf issue, so we need to wait for the bot to re-run after https://github.com/conda-forge/tensorflow-feedstock/pull/359. ~See https://github.com/conda-forge/tensorflow-feedstock/pull/361 :pray:~ (The PR was merged, but the package was later marked as broken according to https://github.com/conda-forge/tensorflow-feedstock/pull/367#issuecomment-1890757429 :sweat_smile:).

weiji14 commented 7 months ago

Looks like we might need to upgrade from CUDA 11.8 to 12 to get a newer version of tensorflow=2.15.0 from conda-forge with libprotobuf~=4.24 that works with flax>=0.7.4, see https://github.com/conda-forge/tensorflow-feedstock/pull/367#issuecomment-1890784430.

dhruvbalwada commented 7 months ago

@yuvipanda and @jbusecke, just flagging this for your attention, since my hacky workflow seems to have stopped working today.

Recently, to work with jax and an appropriate version of flax, I have been doing two steps on the LEAP hub:

However, as of this morning this leads to a new error: after following these steps, running import jax fails with RuntimeError: jaxlib is version 0.4.12, but this version of jax requires version >= 0.4.19.
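The error is raised when import jax runs, by a check that compares the installed jaxlib against the minimum version that jax requires. The sketch below illustrates that kind of guard; jaxlib_too_old and require_min_jaxlib are hypothetical names, and this is not jax's actual source. The scenario it reproduces (a newer jax paired with the image's jaxlib 0.4.12) is an assumption based on the error text.

```python
from typing import Tuple


def parse_version(v: str) -> Tuple[int, ...]:
    """Turn a release string like '0.4.12' into (0, 4, 12) for comparison."""
    return tuple(int(part) for part in v.split("."))


def jaxlib_too_old(jaxlib_version: str, minimum: str) -> bool:
    """True if the installed jaxlib is older than the minimum jax needs."""
    return parse_version(jaxlib_version) < parse_version(minimum)


def require_min_jaxlib(jaxlib_version: str, minimum: str) -> None:
    """Hypothetical import-time guard: fail fast instead of crashing later."""
    if jaxlib_too_old(jaxlib_version, minimum):
        raise RuntimeError(
            f"jaxlib is version {jaxlib_version}, but this version of jax "
            f"requires version >= {minimum}."
        )


# Assumed scenario from this comment: jax needs >= 0.4.19, image has jaxlib 0.4.12.
try:
    require_min_jaxlib("0.4.12", "0.4.19")
except RuntimeError as exc:
    print(exc)
```

Because the check runs at import, the fix is to keep jax and jaxlib in step (both from conda-forge, or both from pip) rather than upgrading only one of them.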

jbusecke commented 7 months ago

Hey @dhruvbalwada, I suspect this is due to the recent update of the pangeo-docker-image on the LEAP hub.

To unblock you for now, I recommend you manually run from an older image (the LEAP docs provide instructions).

But I think this does not change the core problem here. Is there anything I could help with or test, @weiji14?

weiji14 commented 7 months ago

Right, looks like we'll need to expedite the upgrade to CUDA 12 then, as mentioned at https://github.com/pangeo-data/pangeo-docker-images/issues/489#issuecomment-1911393877. Let me open a PR for that (got some free time today), and then we'll be able to upgrade to newer tensorflow/flax versions.

weiji14 commented 7 months ago

OK, it's not as simple as I thought. I tried running conda-lock to create a lockfile with a newer version of tensorflow=2.15.0 and flax>=0.8.0 with CUDA 12.0, but it errors with:

The following packages are incompatible
├─ flax >=0.8.0  is installable and it requires
│  └─ jax >=0.4.11  with the potential options
│     ├─ jax 0.4.11 would require
│     │  └─ jaxlib >=0.4.7  with the potential options
│     │     ├─ jaxlib [0.4.10|0.4.11|0.4.12|0.4.14|0.4.9] would require
│     │     │  └─ libgrpc >=1.54.2,<1.55.0a0 , which can be installed;
│     │     ├─ jaxlib 0.4.12 would require
│     │     │  └─ libgrpc >=1.56.0,<1.57.0a0 , which can be installed;
│     │     ├─ jaxlib 0.4.14 would require
│     │     │  └─ libgrpc >=1.56.2,<1.57.0a0 , which can be installed;
│     │     ├─ jaxlib [0.4.14|0.4.18|0.4.19] would require
│     │     │  └─ libgrpc >=1.58.1,<1.59.0a0 , which can be installed;
│     │     ├─ jaxlib 0.4.14 would require
│     │     │  └─ libgrpc >=1.57.0,<1.58.0a0 , which can be installed;
│     │     ├─ jaxlib [0.4.20|0.4.23] would require
│     │     │  └─ libgrpc >=1.58.2,<1.59.0a0 , which can be installed;
│     │     ├─ jaxlib 0.4.7 would require
│     │     │  └─ libgrpc >=1.52.1,<1.53.0a0 , which can be installed;
│     │     └─ jaxlib 0.4.7 would require
│     │        └─ libgrpc >=1.54.0,<1.55.0a0 , which can be installed;
│     ├─ jax [0.4.12|0.4.13|0.4.14] would require
│     │  └─ jaxlib >=0.4.11 , which can be installed (as previously explained);
│     ├─ jax [0.4.16|0.4.17|0.4.19|0.4.20] would require
│     │  └─ jaxlib >=0.4.14 , which can be installed (as previously explained);
│     └─ jax [0.4.21|0.4.23] would require
│        └─ jaxlib >=0.4.19 , which can be installed (as previously explained);
└─ tensorflow >=2.15.0 *cuda120* is not installable because it requires
   └─ tensorflow-base [2.15.0 cuda120py310heceb7ac_2|2.15.0 cuda120py310heceb7ac_3|...|2.15.0 cuda120py39hf42b710_3], which requires
      └─ libgrpc >=1.59.3,<1.60.0a0 , which conflicts with any installable versions previously reported.

It looks like we'll need to wait for jaxlib to support CUDA 12 (https://github.com/conda-forge/jaxlib-feedstock/issues/223, https://github.com/conda-forge/jaxlib-feedstock/pull/218), and also be rebuilt with libprotobuf 4.24 (https://github.com/conda-forge/jaxlib-feedstock/pull/221).
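Reading the libgrpc pins out of the solver output makes the blocker explicit: every jaxlib build listed caps libgrpc below 1.59, while tensorflow-base 2.15.0 needs at least 1.59.3. The sketch below tabulates those bounds (transcribed by hand from the output above, keeping the most permissive build per jaxlib version) and confirms that no jaxlib version qualifies. The table and names are illustrative, not pulled from conda-forge metadata.

```python
from typing import Dict, List, Tuple


def parse(v: str) -> Tuple[int, ...]:
    """Parse a dotted release such as '1.59.3' into a comparable tuple."""
    return tuple(int(part) for part in v.split("."))


# Exclusive upper bounds on libgrpc for each jaxlib version in the solver
# output ('<1.59.0a0' read as '<1.59'). Where a version had several builds,
# the highest bound is kept, since that build is the most permissive.
JAXLIB_LIBGRPC_CEILING: Dict[str, str] = {
    "0.4.7": "1.55",
    "0.4.9": "1.55",
    "0.4.10": "1.55",
    "0.4.11": "1.55",
    "0.4.12": "1.57",
    "0.4.14": "1.59",
    "0.4.18": "1.59",
    "0.4.19": "1.59",
    "0.4.20": "1.59",
    "0.4.23": "1.59",
}

# tensorflow-base 2.15.0 (cuda120) needs libgrpc >=1.59.3,<1.60.0a0.
TF_LIBGRPC_MIN = "1.59.3"

# Every jaxlib floor is already below TensorFlow's <1.60 ceiling, so the
# binding condition is whether a jaxlib ceiling lies above TensorFlow's floor.
compatible: List[str] = [
    jaxlib
    for jaxlib, ceiling in JAXLIB_LIBGRPC_CEILING.items()
    if parse(TF_LIBGRPC_MIN) < parse(ceiling)
]
highest_ceiling = max(JAXLIB_LIBGRPC_CEILING.values(), key=parse)

print("highest jaxlib libgrpc ceiling: <" + highest_ceiling)  # <1.59
print("jaxlib versions compatible with tensorflow-base:", compatible)  # []
```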

dhruvbalwada commented 7 months ago

Do you think a slightly older version may be enough for now? Maybe flax>=0.7?

weiji14 commented 7 months ago

Nope, flax>=0.7.0 doesn't work either:

├─ flax >=0.7.0  is installable and it requires
│  └─ jax >=0.4.11  with the potential options
│     ├─ jax 0.4.11 would require
│     │  └─ jaxlib >=0.4.7  with the potential options
│     │     ├─ jaxlib [0.4.10|0.4.11|0.4.12|0.4.14|0.4.9] would require
│     │     │  └─ libgrpc >=1.54.2,<1.55.0a0 , which can be installed;
│     │     ├─ jaxlib 0.4.12 would require
│     │     │  └─ libgrpc >=1.56.0,<1.57.0a0 , which can be installed;
│     │     ├─ jaxlib 0.4.14 would require
│     │     │  └─ libgrpc >=1.56.2,<1.57.0a0 , which can be installed;
│     │     ├─ jaxlib [0.4.14|0.4.18|0.4.19] would require
│     │     │  └─ libgrpc >=1.58.1,<1.59.0a0 , which can be installed;
│     │     ├─ jaxlib 0.4.14 would require
│     │     │  └─ libgrpc >=1.57.0,<1.58.0a0 , which can be installed;
│     │     ├─ jaxlib [0.4.20|0.4.23] would require
│     │     │  └─ libgrpc >=1.58.2,<1.59.0a0 , which can be installed;
│     │     ├─ jaxlib 0.4.7 would require
│     │     │  └─ libgrpc >=1.52.1,<1.53.0a0 , which can be installed;
│     │     └─ jaxlib 0.4.7 would require
│     │        └─ libgrpc >=1.54.0,<1.55.0a0 , which can be installed;
│     ├─ jax [0.4.12|0.4.13|0.4.14] would require
│     │  └─ jaxlib >=0.4.11 , which can be installed (as previously explained);
│     ├─ jax [0.4.16|0.4.17|0.4.19|0.4.20] would require
│     │  └─ jaxlib >=0.4.14 , which can be installed (as previously explained);
│     └─ jax [0.4.21|0.4.23] would require
│        └─ jaxlib >=0.4.19 , which can be installed (as previously explained);
└─ tensorflow >=2.15.0 *cuda120* is not installable because it requires
   └─ tensorflow-base [2.15.0 cuda120py310heceb7ac_2|2.15.0 cuda120py310heceb7ac_3|...|2.15.0 cuda120py39hf42b710_3], which requires
      └─ libgrpc >=1.59.3,<1.60.0a0 , which conflicts with any installable versions previously reported.

I also tried older version combinations with CUDA 11.2 and tensorflow 2.13.x last year (see all my crossed-out links in https://github.com/pangeo-data/pangeo-docker-images/issues/489#issuecomment-1807319576), but none of them work. We really need all the tensorflow/jax libraries on conda-forge to align on a common version of libprotobuf.

dhruvbalwada commented 7 months ago

The new hack that works is:

pip install 'flax==0.7.2' 'jax<=0.4.13' 'ml_dtypes==0.2.0'
mamba install cuda-nvcc==11.6.* -c nvidia

Hopefully the versions will align in the near future.
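To confirm in a running notebook that the pins from the pip command above actually took effect, something like the following could be used. It is a sketch under simplifying assumptions: it only understands "==" and "<=" specifiers, compares only the numeric release part of each version, and does not check the cuda-nvcc install from mamba. PINS, violations and installed_versions are hypothetical names.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, List, Tuple

# Pins from the pip command in this comment.
PINS: Dict[str, str] = {"flax": "==0.7.2", "jax": "<=0.4.13", "ml_dtypes": "==0.2.0"}


def parse(v: str) -> Tuple[int, ...]:
    """Keep only the leading numeric release part, e.g. '0.4.13.dev1' -> (0, 4, 13)."""
    match = re.match(r"\d+(?:\.\d+)*", v)
    if match is None:
        raise ValueError(f"unrecognized version: {v!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def violations(installed: Dict[str, str], pins: Dict[str, str]) -> List[str]:
    """List every pin that the installed versions fail to satisfy."""
    problems = []
    for name, spec in pins.items():
        op, wanted = spec[:2], spec[2:]
        if op not in ("==", "<="):
            raise ValueError(f"unsupported specifier: {spec!r}")
        have = installed.get(name)
        if have is None:
            problems.append(f"{name}: not installed")
        elif op == "==" and parse(have) != parse(wanted):
            problems.append(f"{name}: have {have}, want {spec}")
        elif op == "<=" and parse(have) > parse(wanted):
            problems.append(f"{name}: have {have}, want {spec}")
    return problems


def installed_versions(names: Iterable[str]) -> Dict[str, str]:
    """Look up installed distribution versions, skipping missing ones."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            pass
    return found


if __name__ == "__main__":
    for problem in violations(installed_versions(PINS), PINS) or ["all pins satisfied"]:
        print(problem)
```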