ChristopherBottomsOMRF opened this issue 4 months ago
Hi,
Try this:
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade dm-haiku
Sorry, actually, the issue remains. I had run into a different issue before this one (I had to update the tf-nightly version in speed_ppi.yml to 2.16.0).
I also tried running all vs all and got the same error:
AttributeError: 'Config' object has no attribute 'define_bool_state'
I already tried:
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade dm-haiku
Unfortunately, the issue remains. Does anyone know how to solve this?
Downgrading Jax helped me (jax 0.4.23, jaxlib 0.4.23+cuda12.cudnn89).
You can also try:
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
I get this runtime error while running all vs all.
Could it be that I'm using the wrong version of JAX?
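Since downgrading to jax 0.4.23 was reported to fix this error, one quick way to see whether an environment has a newer JAX installed is a small shell check like this. It is a sketch under assumptions: python is taken to be the interpreter of the active environment (e.g. the speed_ppi conda env), sort -V (GNU coreutils version sort) must be available, and the names KNOWN_GOOD, version_gt and installed_version are hypothetical. 0.4.23 is only the version one commenter reported as working, not a documented compatibility bound.

```shell
#!/usr/bin/env bash
# Sketch: warn if jax/jaxlib are newer than 0.4.23, the version reported
# working in this thread (an observed-good version, not an official bound).
KNOWN_GOOD="0.4.23"

# version_gt A B: succeed if version A sorts strictly after version B.
# Assumes `sort -V` (GNU coreutils version sort) is available.
version_gt() {
  [ "$1" != "$2" ] && [ "$(printf '%s\n%s\n' "$1" "$2" | sort -V | tail -n 1)" = "$1" ]
}

# Print the installed version of a distribution, dropping any local label
# such as +cuda12.cudnn89. Prints nothing if the package (or python) is missing.
# Assumes `python` is the interpreter of the active environment.
installed_version() {
  python -c "import importlib.metadata as m; print(m.version('$1'))" 2>/dev/null | sed 's/+.*//'
}

for pkg in jax jaxlib; do
  v="$(installed_version "$pkg")"
  if [ -z "$v" ]; then
    echo "$pkg: not installed"
  elif version_gt "$v" "$KNOWN_GOOD"; then
    echo "$pkg $v is newer than $KNOWN_GOOD; consider pinning it to $KNOWN_GOOD"
  else
    echo "$pkg $v: not newer than $KNOWN_GOOD"
  fi
done
```

Run it inside the environment used for all vs all. If jax or jaxlib is reported as newer than 0.4.23, the versions mentioned above (jax 0.4.23, jaxlib 0.4.23+cuda12.cudnn89) are what worked for one commenter.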