thomashirtz opened this issue 4 months ago
Hmm, let me look into this. Unfortunately, I don't currently have access to a GPU machine, so it'll be hard for me to test this. Regardless, this reminds me to raise the jax version in the requirements file. Just make sure that the Docker image you are pulling and the jax version you install use the same CUDA and cuDNN versions.
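One part of that alignment can be checked before building anything: the CUDA major version in the base image tag should match the CUDA major version in the jax install extra. The sketch below assumes an `nvidia/cuda`-style tag (e.g. `12.2.2-cudnn8-runtime-ubuntu22.04`) and a jax extra named like `cuda12_pip`. Both strings are example values, not taken from this repo's Dockerfile, and the helper names are hypothetical. The cuDNN version still needs to be checked separately.

```python
import re
from typing import Optional


def cuda_major_from_image_tag(image: str) -> Optional[int]:
    """Hypothetical helper: CUDA major version from an nvidia/cuda-style image,
    e.g. 'nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04' -> 12."""
    _, _, tag = image.partition(":")
    match = re.match(r"(\d+)\.\d+", tag)
    return int(match.group(1)) if match else None


def cuda_major_from_jax_extra(requirement: str) -> Optional[int]:
    """Hypothetical helper: CUDA major version from a jax extra,
    e.g. 'jax[cuda12_pip]' -> 12."""
    match = re.search(r"\[cuda(\d+)", requirement)
    return int(match.group(1)) if match else None


def cuda_aligned(image: str, requirement: str) -> bool:
    """True only if both CUDA majors could be parsed and they are equal."""
    image_major = cuda_major_from_image_tag(image)
    return image_major is not None and image_major == cuda_major_from_jax_extra(requirement)


# Example values only; substitute the image and requirement you actually use.
print(cuda_aligned("nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04", "jax[cuda12_pip]"))  # True
print(cuda_aligned("nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04", "jax[cuda12_pip]"))  # False
```

Comparing only the major version is a deliberate simplification: it catches the most common mismatch (a CUDA 11 image with CUDA 12 wheels, or the reverse) without claiming anything about minor-version compatibility.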
@thomashirtz Did you ever figure out the issue?
No, unfortunately I didn't. I don't have much time to debug this, so I stopped using Docker and switched to a venv.
Describe the bug
Hello!
After building the image from the Dockerfile, I get the error
Cannot import name 'linear_util' from 'jax'
when running the examples. This seems to be caused by an incompatibility between flax and jax: https://stackoverflow.com/questions/78210393/cannot-import-name-linear-util-from-jax (with these settings I do have access to my GPU, a 2070 Max-Q). I therefore tried to install version 0.4.24 by changing requirements.txt from
jax>=0.4.10
to
jax>=0.4.24
and line 36 of the Dockerfile to:
However, I then get an error and can no longer use my GPU:
Do you have any idea how to solve this?
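To separate the import problem from the GPU problem, one option is to run the exact import from the error message inside the container. This is a minimal sketch with a hypothetical function name. It only assumes, as the Stack Overflow thread linked above describes, that the failing code does `from jax import linear_util`. A `False` result means jax is installed but no longer provides that name, which is consistent with a flax release that is too old for the installed jax.

```python
import importlib
from typing import Optional


def probe_linear_util() -> Optional[bool]:
    """Hypothetical diagnostic.

    Returns None if jax is not installed, False if jax imports but
    `from jax import linear_util` fails, and True if that import succeeds.
    """
    try:
        importlib.import_module("jax")
    except ImportError:
        return None
    try:
        # The same import that raises "cannot import name 'linear_util' from 'jax'".
        from jax import linear_util  # noqa: F401
    except ImportError:
        return False
    return True


status = probe_linear_util()
if status is None:
    print("jax is not installed in this environment")
elif status:
    print("`from jax import linear_util` works here")
else:
    print("jax no longer exposes linear_util; check the installed flax version")
```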
Full traceback:
To Reproduce
Steps to reproduce the behavior:
Possible Solution
Change the versions of flax and jax/jaxlib in requirements.txt and the Dockerfile.
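After changing the pins, it helps to confirm what actually got installed inside the image, since a cached layer or a separately installed jaxlib wheel can leave an older version in place. The sketch below uses only the standard library (`importlib.metadata`). The 0.4.24 minimum comes from this issue; applying it to jaxlib as well is an assumption, and the version comparison is deliberately simplified (pre-release suffixes such as `rc1` are ignored). For real use, `packaging.version` is the more robust choice.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def release_tuple(v: str) -> Tuple[int, ...]:
    """Leading numeric release components, e.g. '0.4.24rc1' -> (0, 4, 24)."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) != len(piece):
            break  # stop at a pre-release/dev suffix like 'rc1'
    return tuple(parts)


def at_least(installed: str, minimum: str) -> bool:
    """Compare release tuples after zero-padding them to the same length."""
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_minimum(package: str, minimum: str) -> Optional[bool]:
    """None if the package is not installed, otherwise whether it meets the minimum."""
    try:
        return at_least(version(package), minimum)
    except PackageNotFoundError:
        return None


# Assumed minimums: the jax pin from this issue, mirrored for jaxlib.
MINIMUMS = {"jax": "0.4.24", "jaxlib": "0.4.24"}
for pkg, minimum in MINIMUMS.items():
    print(pkg, check_minimum(pkg, minimum))
```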
Context (Environment)
Linux 24.04 with Docker. This is the pip freeze output when I run the Docker container with the repo's current settings: