glegarda opened this issue 4 months ago
You need to upgrade to Python >= 3.9; Python 3.8 does not support this kind of type annotation.
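The thread does not show the original error or the annotation that triggered it. Assuming it was a built-in generic such as `list[float]` (PEP 585), which only became subscriptable at runtime in Python 3.9, the following minimal sketch shows that failure mode alongside an early version guard. `MIN_PYTHON`, `meets_minimum`, `require_python` and `mean` are hypothetical names used only for illustration.

```python
import sys

# Assumption: the failing code uses PEP 585 built-in generics in annotations,
# e.g. list[float] or dict[str, int]. Without `from __future__ import annotations`,
# Python 3.8 evaluates the annotation when the `def` runs and raises
# "TypeError: 'type' object is not subscriptable".
MIN_PYTHON = (3, 9)


def meets_minimum(version_info, minimum=MIN_PYTHON):
    """Return True if (major, minor) of version_info is at least `minimum`."""
    return tuple(version_info[:2]) >= tuple(minimum)


def require_python(minimum=MIN_PYTHON):
    """Fail early with a readable message instead of an obscure TypeError."""
    if not meets_minimum(sys.version_info, minimum):
        found = ".".join(map(str, sys.version_info[:3]))
        wanted = ".".join(map(str, minimum))
        raise RuntimeError(f"Python >= {wanted} is required, found {found}")


# Runs before the annotated definition below, so 3.8 stops here.
require_python()


def mean(values: list[float]) -> float:
    # This annotation is the kind that fails on Python 3.8.
    return sum(values) / len(values)
```

On Python 3.8.10 the guard raises a `RuntimeError` naming the required version before the interpreter reaches the `list[float]` annotation, rather than the `TypeError` that the bare definition would produce.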
Thanks @arnauqb for stepping in and helping. And apologies @glegarda for not being more specific about Python and JAX versions. I will go back and annotate the versions of each requirement more rigorously once I get a chance.
Thank you both for your help! I had to do some tinkering, but eventually I got the example working. First, I upgraded to Python 3.11. This got rid of the original error, but caused some others due to further compatibility issues. In case this helps you annotate the required versions, @conorheins, these are the packages I had to install manually or reinstall:
- jax 0.4.19: the JAX version 0.4.28 installed by default throws an error because the 'KeyArray' attribute of 'jax.random' has been deprecated since version 0.4.16 and was removed in 0.4.24. Version 0.4.19 is the first version of JAX that is also compatible with the rest of the packages installed (the package jax-md in the requirements installs flax 0.8.3, which requires JAX version >= 0.4.19). Note that I had to reinstall JAX after installing the requirements, as installing jax-md forces the (re)installation of JAX 0.4.28.
- scipy 1.12.0: the latest SciPy version (1.13.0) throws an error because the attribute 'tril' was removed from 'scipy.linalg'.
- PyQt5: I had to also install this one in order to display the figure from matplotlib.

With these modifications, I was able to run the example.
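Because installing jax-md silently brings JAX back to 0.4.28, a quick check of the installed versions after reinstalling the pinned packages can catch a regression before the example fails. This is a minimal sketch: the bounds are read off the versions reported in this thread (JAX >= 0.4.19 for flax 0.8.3 and < 0.4.24 because `jax.random.KeyArray` is gone; SciPy < 1.13.0 because `scipy.linalg.tril` is gone) and are not an official compatibility matrix. `CONSTRAINTS`, `parse_version`, `satisfies`, `describe` and `check_environment` are hypothetical helpers; only `importlib.metadata` is a real standard-library API.

```python
from importlib.metadata import PackageNotFoundError, version

# Bounds taken from this issue thread (an assumption, not an official matrix).
# Each entry is (lower, upper): lower inclusive, upper exclusive, None = unbounded.
CONSTRAINTS = {
    "jax": ((0, 4, 19), (0, 4, 24)),  # flax 0.8.3 needs >=0.4.19; KeyArray removed in 0.4.24
    "scipy": (None, (1, 13, 0)),      # scipy.linalg.tril removed in 1.13.0
}


def parse_version(text):
    """Turn '0.4.19' or '1.13.0rc1' into a 3-tuple of ints, ignoring suffixes."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def satisfies(installed, lower, upper):
    """Check lower <= installed < upper on the numeric release part only."""
    current = parse_version(installed)
    if lower is not None and current < lower:
        return False
    if upper is not None and current >= upper:
        return False
    return True


def describe(lower, upper):
    """Render a bound pair in pip-style specifier syntax, e.g. '>=0.4.19,<0.4.24'."""
    parts = []
    if lower is not None:
        parts.append(">=" + ".".join(map(str, lower)))
    if upper is not None:
        parts.append("<" + ".".join(map(str, upper)))
    return ",".join(parts) or "any version"


def check_environment(constraints=CONSTRAINTS):
    """Return human-readable problems; an empty list means every bound is met."""
    problems = []
    for name, (lower, upper) in constraints.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            problems.append(f"{name}: not installed")
            continue
        if not satisfies(installed, lower, upper):
            problems.append(f"{name} {installed}: expected {describe(lower, upper)}")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print(problem)
```

Run after installing the requirements and reinstalling the pinned versions: an empty output means the combination reported above is in place, while a line such as `jax 0.4.28: expected >=0.4.19,<0.4.24` shows which package was upgraded behind your back.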
Thanks a lot for documenting this so thoroughly, @glegarda. Good to know about the removal of the KeyArray attribute in newer versions of JAX. Because JAX is still in experimental development, deprecations and breaks in backward compatibility like this come up frustratingly often. So I'll either (A) freeze the requirements to an earlier JAX version (such as 0.4.19) that predates 0.4.24 while still being new enough to be compatible with the remaining packages (jax-md, flax 0.8.3, etc.), or (B) update the code to be consistent with the latest versions of JAX (0.4.24 and later).
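For option (B), the SciPy break has a direct replacement: SciPy 1.11 deprecated `scipy.linalg.tri`, `tril` and `triu` and pointed users to the NumPy functions of the same name, and 1.13 removed them. The sketch below assumes the project calls `scipy.linalg.tril` on a 2-D array (the thread does not show the call site); for JAX arrays, `jax.numpy.tril` plays the same role. On the JAX side, JAX's deprecation notice for `jax.random.KeyArray` recommends `jax.Array` for annotations. `lower_triangle` is a hypothetical helper name.

```python
import numpy as np

# scipy.linalg.tril was removed in SciPy 1.13. Assumed original call site:
#     from scipy.linalg import tril
#     masked = tril(matrix, k)
# For 2-D input, numpy.tril does the same thing: it returns a copy with every
# entry above the k-th diagonal set to zero.


def lower_triangle(matrix, k=0):
    """Return a copy of `matrix` with entries above the k-th diagonal zeroed."""
    return np.tril(np.asarray(matrix), k=k)


if __name__ == "__main__":
    m = np.arange(1, 10).reshape(3, 3)
    print(lower_triangle(m))        # main diagonal and below
    print(lower_triangle(m, k=-1))  # strictly below the main diagonal
```

Swapping the import rather than pinning SciPy keeps the code working on both 1.12 and 1.13+, since `numpy.tril` exists in every NumPy version those SciPy releases support.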
Hello,
I followed the JAX setup instructions and tried to run the demo script, but got the following error:
I am working on an Ubuntu 20.04.6 LTS x86_64 machine with an NVIDIA GeForce RTX 3060 and Python 3.8.10, and I tried both the GPU and CPU versions of JAX, but the error remains.
Any clue as to what might be going on? Perhaps some version compatibility issue?
Thanks!