Closed: dhruvbalwada closed this issue 3 months ago
The latest version of flax on conda-forge is still 0.6.1, see https://anaconda.org/conda-forge/flax/files. It seems like the package has fallen out of maintenance for a while, but there's some effort to bump it up to v0.7.x, see e.g. https://github.com/conda-forge/flax-feedstock/pull/29.
We could either wait for the updates on flax-feedstock, or use pip to install flax in the ml-notebook/environment.yml file as a temporary measure.
OK, looks like flax=0.7.4 and flax=0.7.5 are available on conda-forge now. However, there's a dependency conflict with tensorflow. I tried with this environment.yml file:
name: pangeo
channels:
- conda-forge
- nodefaults
dependencies:
- flax>=0.7.0
- jax
- jupyterlab-nvdashboard
- keras-cv
- tensorflow>=2.13.1=*cuda112*
and got this traceback:
Locking dependencies for ['linux-64']...
INFO:conda_lock.conda_solver:linux-64 using specs ['flax >=0.7.0', 'jax', 'jupyterlab-nvdashboard', 'keras-cv', 'tensorflow >=2.13.1 *cuda112*', 'adlfs', 'argopy', 'awscli', 'boto3', 'bottleneck', 'cartopy', 'cdsapi', 'cfgrib', 'ciso', 'cmocean', 'dask-ml', 'datashader', 'descartes', 'earthaccess', 'eofs', 'erddapy', 'esmpy', 'fastjmd95', 'flox', 'fsspec', 'gcm_filters', 'gcsfs', 'gh', 'gh-scoped-creds', 'geocube', 'geopandas', 'geopy', 'geoviews-core', 'git-lfs', 'gsw', 'h5netcdf', 'h5py', 'holoviews', 'hvplot', 'intake', 'intake-esm', 'intake-geopandas', 'intake-stac', 'intake-xarray', 'ipykernel', 'ipyleaflet', 'ipytree', 'ipywidgets', 'jupyterlab-git', 'jupyter-panel-proxy', 'jupyter-resource-usage', 'kerchunk', 'lxml', 'lz4', 'matplotlib-base', 'metpy', 'nb_conda_kernels', 'nbstripout', 'nc-time-axis', 'netcdf4', 'numbagg', 'numcodecs', 'numpy', 'numpy_groupies', 'odc-stac', 'pandas', 'panel', 'parcels', 'param', 'pop-tools', 'pyarrow', 'pycamhd', 'pydap', 'pystac', 'pystac-client', 'python-blosc', 'python-gist', 'python-graphviz', 'rasterio', 'rechunker', 'rio-cogeo', 'rioxarray', 's3fs', 'satpy', 'scikit-image', 'scikit-learn', 'scipy', 'seaborn', 'sparse', 'snakeviz', 'stackstac', 'tiledb-py', 'timezonefinder', 'xarray', 'xarrayutils', 'xarray-datatree', 'xarray_leaflet', 'xarray-spatial', 'xbatcher', 'xcape', 'xclim', 'xesmf', 'xgboost', 'xgcm', 'xhistogram', 'xmip', 'xmitgcm', 'xpublish', 'xrft', 'xskillscore', 'zarr', 'python 3.11.*', 'pangeo-notebook 2023.11.11.*', 'pip']
Failed to parse json, Expecting value: line 1 column 1 (char 0)
Could not lock the environment for platform linux-64
Could not solve for environment specs
The following packages are incompatible
├─ flax >=0.7.0 is installable and it requires
│ └─ tensorstore with the potential options
│ ├─ tensorstore [0.1.44|0.1.46|0.1.47|0.1.48] would require
│ │ └─ libprotobuf [>=4.24.3,<4.24.4.0a0 |>=4.24.4,<4.24.5.0a0 ], which requires
│ │ └─ libabseil >=20230802.1,<20230803.0a0 , which can be installed;
│ └─ tensorstore 0.1.44 would require
│ └─ libprotobuf >=4.23.4,<4.23.5.0a0 with the potential options
│ ├─ libprotobuf 4.23.4, which can be installed;
│ └─ libprotobuf 4.23.4 would require
│ └─ libabseil >=20230802.0,<20230803.0a0 , which can be installed;
└─ tensorflow >=2.13.1 *cuda112* is uninstallable because it requires
└─ tensorflow-base [2.13.1 cuda112py310hbb601f2_1|2.13.1 cuda112py311h8bdbb6c_1|2.13.1 cuda112py38h79651c7_1|2.13.1 cuda112py39h85a252b_1], which requires
├─ libabseil >=20230125.3,<20230126.0a0 , which conflicts with any installable versions previously reported;
└─ libgrpc >=1.54.3,<1.55.0a0 , which requires
└─ libprotobuf >=3.21.12,<3.22.0a0 , which conflicts with any installable versions previously reported.
It seems like flax -> tensorstore depends on libprotobuf~=4.24, but tensorflow-base -> libgrpc depends on libprotobuf~=3.21. This looks like a longstanding issue, see https://github.com/conda-forge/tensorflow-feedstock/issues/288, and the migration at ~either https://github.com/conda-forge/tensorflow-feedstock/pull/347 or~ https://github.com/conda-forge/tensorflow-feedstock/pull/342 might help.
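To make the conflict concrete, here is a minimal sketch (not how the conda solver works internally) that treats each pin as a half-open version interval and checks whether two pins can be satisfied at once. The helper names `parse` and `overlap` are made up for illustration, and pre-release suffixes such as `0a0` are dropped, which is a simplification.

```python
# Simplified model of conda version ranges: each constraint is a half-open
# interval [lower, upper) over numeric version tuples. Pre-release suffixes
# such as "0a0" are dropped, which is an approximation.


def parse(version):
    """Turn '4.24.3' into (4, 24, 3) so tuples compare like versions."""
    return tuple(int(part) for part in version.split("."))


def overlap(range_a, range_b):
    """Return the common [lower, upper) interval of two ranges, or None."""
    lower = max(parse(range_a[0]), parse(range_b[0]))
    upper = min(parse(range_a[1]), parse(range_b[1]))
    return (lower, upper) if lower < upper else None


# Bounds copied from the solver message above.
tensorstore_protobuf = ("4.24.3", "4.24.4")  # via flax -> tensorstore
libgrpc_protobuf = ("3.21.12", "3.22.0")     # via tensorflow-base -> libgrpc

# No version of libprotobuf satisfies both pins, so this prints None.
print(overlap(tensorstore_protobuf, libgrpc_protobuf))
```

The same check on the two libabseil pins from the message (20230802.x vs 20230125.x) also finds no overlap, which is why the solver reports both as conflicts.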
Edit: ~Might need to wait for tensorflow 2.15.0 at https://github.com/conda-forge/tensorflow-feedstock/pull/353?~ Nope, conda-forge's tensorflow=2.15.0 doesn't help. They still need to handle the libprotobuf issue, so we have to wait for the bot re-run after https://github.com/conda-forge/tensorflow-feedstock/pull/359. ~See https://github.com/conda-forge/tensorflow-feedstock/pull/361 :pray:~ (PR was merged, but the package was later marked as broken according to https://github.com/conda-forge/tensorflow-feedstock/pull/367#issuecomment-1890757429 :sweat_smile:).
Looks like we might need to upgrade from CUDA 11.8 to 12 to get a newer version of tensorflow=2.15.0 from conda-forge with libprotobuf~=4.24 that works with flax>=0.7.4, see https://github.com/conda-forge/tensorflow-feedstock/pull/367#issuecomment-1890784430.
@yuvipanda and @jbusecke - just getting your attention here, since my hacky workflow seems to have stopped working today.
Recently, to work with jax and an appropriate version of flax, I would do two steps on the LEAP hub:
mamba install cuda-nvcc==11.6.* -c nvidia
pip install flax==0.6.10
However, as of this morning this leads to a new error. After following these steps, I get the error:
RuntimeError: jaxlib is version 0.4.12, but this version of jax requires version >= 0.4.19.
when I try to run import jax.
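When jax refuses to import because of a jaxlib mismatch like this, it can help to read the installed versions from package metadata instead, since that does not import jax itself. A minimal sketch using only the standard library; the helper names `installed_version` and `report` are hypothetical:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of a package, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def report(packages=("jax", "jaxlib", "flax")):
    """Map each package name to its installed version (None if missing)."""
    return {name: installed_version(name) for name in packages}


if __name__ == "__main__":
    for name, found in report().items():
        print(f"{name}: {found or 'not installed'}")
```

Comparing the jax and jaxlib lines of the output against the requirement in the error message shows which of the two needs to be pinned or upgraded.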
Hey @dhruvbalwada, I suspect this is due to the recent update of the pangeo-docker-image on the LEAP hub.
To unblock you for now, I recommend you manually run from an older image (the LEAP docs provide instructions).
But I don't think this changes the core problem here. Is there anything I could help with or test to contribute here, @weiji14?
Right, looks like we'll need to expedite the upgrade to CUDA 12 then as mentioned at https://github.com/pangeo-data/pangeo-docker-images/issues/489#issuecomment-1911393877. Let me open a PR for that (got some free time today), and then we'll be able to upgrade to newer tensorflow/flax versions.
Ok, not as simple as I thought. I tried running conda-lock to create a lockfile with a newer version of tensorflow=2.15.0 and flax>=0.8.0 with CUDA 12.0, but it errors with:
The following packages are incompatible
├─ flax >=0.8.0 is installable and it requires
│ └─ jax >=0.4.11 with the potential options
│ ├─ jax 0.4.11 would require
│ │ └─ jaxlib >=0.4.7 with the potential options
│ │ ├─ jaxlib [0.4.10|0.4.11|0.4.12|0.4.14|0.4.9] would require
│ │ │ └─ libgrpc >=1.54.2,<1.55.0a0 , which can be installed;
│ │ ├─ jaxlib 0.4.12 would require
│ │ │ └─ libgrpc >=1.56.0,<1.57.0a0 , which can be installed;
│ │ ├─ jaxlib 0.4.14 would require
│ │ │ └─ libgrpc >=1.56.2,<1.57.0a0 , which can be installed;
│ │ ├─ jaxlib [0.4.14|0.4.18|0.4.19] would require
│ │ │ └─ libgrpc >=1.58.1,<1.59.0a0 , which can be installed;
│ │ ├─ jaxlib 0.4.14 would require
│ │ │ └─ libgrpc >=1.57.0,<1.58.0a0 , which can be installed;
│ │ ├─ jaxlib [0.4.20|0.4.23] would require
│ │ │ └─ libgrpc >=1.58.2,<1.59.0a0 , which can be installed;
│ │ ├─ jaxlib 0.4.7 would require
│ │ │ └─ libgrpc >=1.52.1,<1.53.0a0 , which can be installed;
│ │ └─ jaxlib 0.4.7 would require
│ │ └─ libgrpc >=1.54.0,<1.55.0a0 , which can be installed;
│ ├─ jax [0.4.12|0.4.13|0.4.14] would require
│ │ └─ jaxlib >=0.4.11 , which can be installed (as previously explained);
│ ├─ jax [0.4.16|0.4.17|0.4.19|0.4.20] would require
│ │ └─ jaxlib >=0.4.14 , which can be installed (as previously explained);
│ └─ jax [0.4.21|0.4.23] would require
│ └─ jaxlib >=0.4.19 , which can be installed (as previously explained);
└─ tensorflow >=2.15.0 *cuda120* is not installable because it requires
└─ tensorflow-base [2.15.0 cuda120py310heceb7ac_2|2.15.0 cuda120py310heceb7ac_3|...|2.15.0 cuda120py39hf42b710_3], which requires
└─ libgrpc >=1.59.3,<1.60.0a0 , which conflicts with any installable versions previously reported.
It looks like we'll need to wait for jaxlib to support CUDA 12 (https://github.com/conda-forge/jaxlib-feedstock/issues/223, https://github.com/conda-forge/jaxlib-feedstock/pull/218), and also be rebuilt with libprotobuf 4.24 (https://github.com/conda-forge/jaxlib-feedstock/pull/221).
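Solver messages like the one above get long, so a quick way to find the actual blocker is to pull out only the pins flagged with "which conflicts". A rough sketch, assuming the message wording stays as shown in this thread; the regex and the `conflicting_pins` helper are illustrative and not part of conda-lock or mamba:

```python
import re

# Matches e.g. "libgrpc >=1.59.3,<1.60.0a0 , which conflicts with ..."
CONFLICT = re.compile(r"([A-Za-z0-9_.\-]+)\s+(\S+)\s*,\s*which conflicts")


def conflicting_pins(solver_output):
    """Return (package, constraint) pairs the solver flagged as conflicting."""
    return CONFLICT.findall(solver_output)


# Two lines copied from the message above.
sample = """\
│ │ │ └─ libgrpc >=1.58.2,<1.59.0a0 , which can be installed;
└─ libgrpc >=1.59.3,<1.60.0a0 , which conflicts with any installable versions previously reported.
"""

print(conflicting_pins(sample))  # [('libgrpc', '>=1.59.3,<1.60.0a0')]
```

Here the only flagged pin is the libgrpc range required by tensorflow-base, which none of the listed jaxlib builds can share.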
Do you think a slightly older version may be enough for now? Maybe flax>=0.7?
Nope, flax>=0.7.0 doesn't work either:
├─ flax >=0.7.0 is installable and it requires
│ └─ jax >=0.4.11 with the potential options
│ ├─ jax 0.4.11 would require
│ │ └─ jaxlib >=0.4.7 with the potential options
│ │ ├─ jaxlib [0.4.10|0.4.11|0.4.12|0.4.14|0.4.9] would require
│ │ │ └─ libgrpc >=1.54.2,<1.55.0a0 , which can be installed;
│ │ ├─ jaxlib 0.4.12 would require
│ │ │ └─ libgrpc >=1.56.0,<1.57.0a0 , which can be installed;
│ │ ├─ jaxlib 0.4.14 would require
│ │ │ └─ libgrpc >=1.56.2,<1.57.0a0 , which can be installed;
│ │ ├─ jaxlib [0.4.14|0.4.18|0.4.19] would require
│ │ │ └─ libgrpc >=1.58.1,<1.59.0a0 , which can be installed;
│ │ ├─ jaxlib 0.4.14 would require
│ │ │ └─ libgrpc >=1.57.0,<1.58.0a0 , which can be installed;
│ │ ├─ jaxlib [0.4.20|0.4.23] would require
│ │ │ └─ libgrpc >=1.58.2,<1.59.0a0 , which can be installed;
│ │ ├─ jaxlib 0.4.7 would require
│ │ │ └─ libgrpc >=1.52.1,<1.53.0a0 , which can be installed;
│ │ └─ jaxlib 0.4.7 would require
│ │ └─ libgrpc >=1.54.0,<1.55.0a0 , which can be installed;
│ ├─ jax [0.4.12|0.4.13|0.4.14] would require
│ │ └─ jaxlib >=0.4.11 , which can be installed (as previously explained);
│ ├─ jax [0.4.16|0.4.17|0.4.19|0.4.20] would require
│ │ └─ jaxlib >=0.4.14 , which can be installed (as previously explained);
│ └─ jax [0.4.21|0.4.23] would require
│ └─ jaxlib >=0.4.19 , which can be installed (as previously explained);
└─ tensorflow >=2.15.0 *cuda120* is not installable because it requires
└─ tensorflow-base [2.15.0 cuda120py310heceb7ac_2|2.15.0 cuda120py310heceb7ac_3|...|2.15.0 cuda120py39hf42b710_3], which requires
└─ libgrpc >=1.59.3,<1.60.0a0 , which conflicts with any installable versions previously reported.
I've also tried older version combinations with CUDA 11.2 and tensorflow 2.13.x last year (see all my crossed out links in https://github.com/pangeo-data/pangeo-docker-images/issues/489#issuecomment-1807319576), but none of them work. We really need to get all the tensorflow/jax libraries to align on the correct version of libprotobuf in conda-forge.
The new hack that is working is:
pip install 'flax==0.7.2' 'jax<=0.4.13' 'ml_dtypes==0.2.0'
mamba install cuda-nvcc==11.6.* -c nvidia
Hopefully alignment will come in the near future.
Describe the bug
Current version of flax (0.6.1) on the image does not work properly with the jax version (0.4.13).
To Reproduce
Issue can be reproduced by doing
from flax.training import checkpoints
which will give the error ModuleNotFoundError: No module named 'jax.experimental.global_device_array'. This has been discussed in https://github.com/google/flax/issues/3087.
Expected behavior
Flax should be importable.
Infrastructure (Where you are running this image):
Solution
At the moment I solve this doing
pip install flax==0.6.10
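The reproduce step can also be wrapped in a small smoke test, e.g. to run against a freshly built image. This is only a sketch: the `try_import` helper is hypothetical, and it catches any exception so that it reports both the ModuleNotFoundError from the bug report and the RuntimeError that jax raises on a jaxlib version mismatch.

```python
import importlib


def try_import(module_name):
    """Import a module by name; return (module, None) or (None, error text)."""
    try:
        return importlib.import_module(module_name), None
    except Exception as exc:  # broad on purpose: this is a smoke test
        return None, f"{type(exc).__name__}: {exc}"


if __name__ == "__main__":
    module, error = try_import("flax.training.checkpoints")
    print("flax checkpoints import OK" if error is None else error)
```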