machine-discovery / deer

Parallelizing non-linear sequential models over the sequence length
BSD 3-Clause "New" or "Revised" License

Package installation problem #5

Closed: yhl48 closed this issue 11 months ago

yhl48 commented 11 months ago

Related to PR #4. The pinned JAX and cuDNN versions don't work on my machine.

Replacing

pip install --upgrade jax==0.4.11 jaxlib==0.4.11+cuda12.cudnn88 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade -e .

with

pip install --upgrade -e .
pip install --upgrade jax==0.4.11 jaxlib==0.4.11+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install nvidia-cudnn-cu11==8.6.0.163

works for me.
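The two jaxlib pins differ only in their local version tag: `+cuda12.cudnn88` is a build for CUDA 12 with cuDNN 8.8, while `+cuda11.cudnn86` is a build for CUDA 11 with cuDNN 8.6, which is why the extra `nvidia-cudnn-cu11==8.6.0.163` pin lines up with it. The helper below is a minimal sketch for checking that a jaxlib pin and a cuDNN wheel pin agree. It assumes the `+cudaXX.cudnnMN` tag format seen in this thread (single-digit cuDNN major version), and the function names are hypothetical, not part of deer or JAX.

```python
import re
from typing import NamedTuple, Optional


class CudaTag(NamedTuple):
    cuda_major: int
    cudnn_major: int
    cudnn_minor: int


# Assumed tag format, e.g. "0.4.11+cuda12.cudnn88" -> CUDA 12, cuDNN 8.8
_TAG_RE = re.compile(r"\+cuda(\d+)\.cudnn(\d)(\d+)$")
# e.g. "nvidia-cudnn-cu11==8.6.0.163" -> CUDA 11, cuDNN 8.6
_PIN_RE = re.compile(r"^nvidia-cudnn-cu(\d+)==(\d+)\.(\d+)")


def parse_jaxlib_tag(jaxlib_version: str) -> Optional[CudaTag]:
    """Return the CUDA/cuDNN build info in a jaxlib version, or None if untagged."""
    m = _TAG_RE.search(jaxlib_version)
    if m is None:
        return None
    return CudaTag(*(int(g) for g in m.groups()))


def cudnn_pin_matches(jaxlib_version: str, cudnn_requirement: str) -> bool:
    """True if the nvidia-cudnn pin targets the same CUDA/cuDNN build as jaxlib."""
    tag = parse_jaxlib_tag(jaxlib_version)
    m = _PIN_RE.match(cudnn_requirement)
    if tag is None or m is None:
        return False
    return tuple(int(g) for g in m.groups()) == tuple(tag)
```

For example, `cudnn_pin_matches("0.4.11+cuda11.cudnn86", "nvidia-cudnn-cu11==8.6.0.163")` is true, while pairing the same cuDNN pin with the `+cuda12.cudnn88` build is false.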

mfkasim1 commented 11 months ago

Probably because you're using CUDA 11. If you keep the original order (i.e., run `pip install --upgrade -e .` last), does it still work?
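Because each `pip install` call resolves dependencies on its own, the command that runs last may replace a jaxlib wheel installed by an earlier one. A quick way to see which pins actually ended up installed after running the commands in either order is to query package metadata. This is a minimal sketch using only the standard library's `importlib.metadata`; the package list is an assumption based on the names in this thread.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Assumed set of distributions whose versions depend on the install order
PACKAGES = ("jax", "jaxlib", "nvidia-cudnn-cu11", "nvidia-cudnn-cu12")


def installed_versions(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    # Print one line per package, e.g. to compare before/after reordering installs
    for name, ver in installed_versions().items():
        print(f"{name:20s} {ver or 'not installed'}")
```

If `jaxlib` shows a version without the `+cuda11.cudnn86` tag after the final command, the CUDA build was most likely overwritten.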

yhl48 commented 11 months ago

It didn't work. I think it might be because of pytorch_lightning.

mfkasim1 commented 11 months ago

Solved by f0b06b6