Closed by vwxyzjn 1 year ago.
Reproduced with:

poetry add jax=="0.4.8" flax=="0.6.8" optax=="0.1.4" chex=="0.1.5" orbax=="0.1.4"

You can then install the CUDA 11 / cuDNN 8.2 build of JAX with:

poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
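Before reproducing, it can help to confirm that the environment actually matches these pins. The following is a minimal sketch, assuming it is saved under a hypothetical name such as `check_pins.py` and run with `poetry run python check_pins.py`. It only compares installed distribution versions through the standard library's `importlib.metadata`; it does not check that the CUDA build of jaxlib is the one JAX picks up at runtime.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the `poetry add` command in this issue.
EXPECTED = {
    "jax": "0.4.8",
    "flax": "0.6.8",
    "optax": "0.1.4",
    "chex": "0.1.5",
    "orbax": "0.1.4",
}


def check_pins(expected):
    """Return {package: (expected, installed_or_None)} for every pin that does not match."""
    mismatches = {}
    for name, want in expected.items():
        try:
            have = version(name)  # installed distribution version string
        except PackageNotFoundError:
            have = None  # package is not installed in this environment
        if have != want:
            mismatches[name] = (want, have)
    return mismatches


if __name__ == "__main__":
    problems = check_pins(EXPECTED)
    for name, (want, have) in problems.items():
        print(f"{name}: expected {want}, found {have or 'not installed'}")
    if not problems:
        print("All pinned versions are installed.")
```

An empty result means every pinned package is present at exactly the listed version; any mismatch is printed with the version actually found.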
Closed in favor of #13