haukekoehn opened this issue 7 months ago.
@haukekoehn thanks for the hint! I got it working (on Ubuntu 22.04) via:

    pip3 install jax-0.3.25.tar.gz
    pip3 install -U --no-deps git+https://github.com/CosmoStatGW/gwfast
    ...
    Successfully installed gwfast-1.1.1
Running a gwfast example then failed with:

    RuntimeError: jaxlib version 0.4.25 is newer than and incompatible with jax version 0.3.25. Please update your jax and/or jaxlib packages.
So I updated jax to match the installed jaxlib:

    pip3 install "jax[cpu]==0.4.25"
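The RuntimeError above is jax rejecting a jaxlib that is newer than itself, which is why bringing jax up to the same version resolved it. Below is a minimal sketch of that rule as the error message states it; it is not jax's actual implementation, and the helper names `parse_version` and `check_jax_pair` are hypothetical.

```python
def parse_version(version: str) -> tuple:
    # Keep only the leading numeric components and drop any local suffix
    # such as "+cuda11", so "0.4.25" compares greater than "0.4.3"
    # (a plain string comparison would get this wrong).
    parts = []
    for part in version.split("+")[0].split("."):
        if not part.isdigit():
            break
        parts.append(int(part))
    return tuple(parts)


def check_jax_pair(jax_version: str, jaxlib_version: str) -> None:
    # Hypothetical re-creation of the condition named in the error message:
    # a jaxlib newer than jax is rejected.
    if parse_version(jaxlib_version) > parse_version(jax_version):
        raise RuntimeError(
            f"jaxlib version {jaxlib_version} is newer than and incompatible "
            f"with jax version {jax_version}. "
            "Please update your jax and/or jaxlib packages."
        )


if __name__ == "__main__":
    # The failing combination from this thread, then the fixed one.
    for jax_v, jaxlib_v in [("0.3.25", "0.4.25"), ("0.4.25", "0.4.25")]:
        try:
            check_jax_pair(jax_v, jaxlib_v)
            print(f"jax {jax_v} / jaxlib {jaxlib_v}: OK")
        except RuntimeError as err:
            print(f"jax {jax_v} / jaxlib {jaxlib_v}: {err}")
```

Real jax releases also enforce a minimum jaxlib version, which this sketch leaves out.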
Finally, I had to manually fix code that used deprecated jax APIs, e.g. https://github.com/google/jax/issues/17244
Now it works!
A workaround is to use conda, which still lets you install the older versions:

    conda install jaxlib==0.3.25

This installs that version directly into your conda environment, after which you can run pip install gwfast.
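Before running `pip install gwfast`, it can help to confirm which jax and jaxlib versions the environment actually resolved. A minimal sketch using only the standard library's `importlib.metadata`; the 0.4.0 upper bound is taken from the requirement discussed in this thread, and the helpers `installed_version` and `below` are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(name: str) -> Optional[str]:
    # Return the installed distribution's version, or None if it is absent.
    try:
        return version(name)
    except PackageNotFoundError:
        return None


def below(version_str: str, bound: str) -> bool:
    # Compare leading numeric components only; a local suffix like "+cuda11"
    # is ignored, and a pre-release like "0.4.0rc1" counts as below "0.4.0".
    def key(v: str) -> tuple:
        parts = []
        for part in v.split("+")[0].split("."):
            if not part.isdigit():
                break
            parts.append(int(part))
        return tuple(parts)

    return key(version_str) < key(bound)


if __name__ == "__main__":
    for pkg in ("jax", "jaxlib"):
        found = installed_version(pkg)
        if found is None:
            print(f"{pkg}: not installed")
        else:
            status = "ok" if below(found, "0.4.0") else "too new for <0.4.0"
            print(f"{pkg} {found}: {status}")
```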
Since installing gwfast apparently requires jax[cpu]<0.4.0, but jaxlib<0.4.0 is no longer available on PyPI (see the discussion at https://github.com/google/jax/issues/18368), a naive installation will fail.
A possible solution is to manually download the .whl files for version 0.3.25 or lower from https://storage.googleapis.com/jax-releases/jax_releases.html and install them with pip.
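Picking the right file by hand means matching the wheel's Python tag and platform to your interpreter. A minimal sketch that filters wheel filenames using the standard wheel naming convention (PEP 427); the helper `pick_wheels` and the example filenames are hypothetical, so copy the real names from the release page.

```python
import sys


def wheel_tags(filename: str) -> dict:
    # PEP 427 layout: {dist}-{version}(-{build})?-{python}-{abi}-{platform}.whl
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 6:
        parts = parts[:2] + parts[3:]  # drop the optional build tag
    if len(parts) != 5:
        raise ValueError(f"unexpected wheel name: {filename}")
    dist, version, python_tag, abi_tag, platform_tag = parts
    return {"dist": dist, "version": version, "python": python_tag,
            "abi": abi_tag, "platform": platform_tag}


def release_key(version: str) -> tuple:
    # Leading numeric components only, e.g. "0.3.25+cuda11" -> (0, 3, 25).
    key = []
    for part in version.split("+")[0].split("."):
        if not part.isdigit():
            break
        key.append(int(part))
    return tuple(key)


def pick_wheels(filenames, max_version="0.3.25", platform_prefix="manylinux"):
    # Keep wheels built for this CPython version, for the given platform
    # family, whose release is <= max_version.
    python_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
    picked = []
    for name in filenames:
        tags = wheel_tags(name)
        if (tags["python"] == python_tag
                and tags["platform"].startswith(platform_prefix)
                and release_key(tags["version"]) <= release_key(max_version)):
            picked.append(name)
    return picked


if __name__ == "__main__":
    # Hypothetical filenames; the result depends on the running interpreter.
    candidates = [
        "jaxlib-0.3.25-cp310-cp310-manylinux2014_x86_64.whl",
        "jaxlib-0.4.1-cp310-cp310-manylinux2014_x86_64.whl",
        "jaxlib-0.3.25-cp39-cp39-manylinux2014_x86_64.whl",
    ]
    print(pick_wheels(candidates))
```

The chosen file can then be installed with `pip install ./<filename>.whl`. Alternatively, pip's `-f`/`--find-links` option accepts an HTML page of links, so `pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html` may let pip select the matching wheel itself, assuming one exists for your platform.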