conda create --name desc-env 'python=3.11'
conda activate desc-env
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# remove the jax lines from requirements.txt, since we have already installed jax above
sed -i '/jax/d' ./requirements.txt
# then install as usual
pip install --editable .
# optionally install developer requirements (if you want to run tests)
pip install -r devtools/dev-requirements.txt
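After these steps, it is worth confirming that JAX actually sees the GPU rather than silently falling back to the CPU. The sketch below is a hypothetical helper (`check_jax_gpu` is not part of DESC or JAX). It only uses `jax.devices()` and `jax.default_backend()` and assumes it is run inside the activated `desc-env` on a GPU node.

```shell
#!/usr/bin/env bash
# Hypothetical helper: print what JAX detects and exit with status 0
# only if JAX's default backend is the GPU.
check_jax_gpu() {
  python - <<'EOF'
import jax

print("jax version:", jax.__version__)
print("devices:", jax.devices())
backend = jax.default_backend()
print("default backend:", backend)
# Non-zero exit status if JAX fell back to CPU (or another backend)
raise SystemExit(0 if backend == "gpu" else 1)
EOF
}
```

Calling `check_jax_gpu` on a GPU node should list CUDA devices and report `gpu`. A `cpu` backend suggests that the CUDA-enabled jaxlib was not picked up.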
These commands seem to produce a driver warning, though. It would also be preferable to use the CUDA installation that is already present on the machine instead of pulling CUDA in through this pip call.
These steps should work; they are based on the pip installation instructions here: https://github.com/PrincetonUniversity/intro_ml_libs/tree/master/jax#pip-installation
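One plausible cause of the driver warning mentioned above is that the NVIDIA driver on the machine is older than the version the CUDA 12 wheels expect. Below is a minimal sketch for comparing the two. It assumes GNU `sort -V` is available and that `nvidia-smi` is on the PATH. The function names `version_ge` and `check_driver` are hypothetical. The minimum driver version is deliberately left as a parameter; take it from the JAX installation docs rather than hard-coding a guess.

```shell
#!/usr/bin/env bash
# Succeed if dotted version $1 >= $2, using GNU sort's version ordering:
# after sorting ascending, the smaller version comes first.
version_ge() {
  [ "$(printf '%s\n%s\n' "$1" "$2" | sort -V | head -n 1)" = "$2" ]
}

# Hypothetical check: compare the local NVIDIA driver with a minimum
# version passed as $1 (look the value up in the JAX install docs).
check_driver() {
  local min="$1" current
  if ! command -v nvidia-smi >/dev/null 2>&1; then
    echo "nvidia-smi not found; is this a GPU node?" >&2
    return 2
  fi
  # One line per GPU; the driver version is the same, so keep the first
  current=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1)
  if version_ge "$current" "$min"; then
    echo "driver $current >= required $min"
  else
    echo "driver $current is older than required $min" >&2
    return 1
  fi
}
```

For example, running `check_driver <minimum-version>` on a GPU node prints the comparison and returns non-zero if the driver is too old. If you are considering the locally installed CUDA instead, `nvcc --version` shows which CUDA toolkit (if any) is on the PATH.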