conorheins / collective_motion_actinf

Code for simulating collective motion from groups of continuous-time and -space active inference agents.

Running JAX demo script #1

Open · glegarda opened this issue 4 months ago

glegarda commented 4 months ago

Hello,

I followed the JAX setup instructions and tried to run the demo script, but I got the following error:

Traceback (most recent call last):
  File "src/demo_nolearning.py", line 8, in <module>
    from utils import initialize_meta_params, get_default_inits, run_single_simulation, str2bool
  File "/home/guillermo/Code/git-projects/collective_motion_actinf/jax/src/utils.py", line 15, in <module>
    from genprocess import get_observations, get_observations_special, advance_positions, init_gen_process, compute_Dgroup_and_rankings_t, compute_Dgroup_and_rankings_vmapped, compute_turning_magnitudes, compute_integrated_change_magnitude
  File "/home/guillermo/Code/git-projects/collective_motion_actinf/jax/src/genprocess/__init__.py", line 1, in <module>
    from .geometry import *
  File "/home/guillermo/Code/git-projects/collective_motion_actinf/jax/src/genprocess/geometry.py", line 3, in <module>
    from jax_md import space
  File "/home/guillermo/.venv/jax/lib/python3.8/site-packages/jax_md/__init__.py", line 16, in <module>
    from jax_md import energy
  File "/home/guillermo/.venv/jax/lib/python3.8/site-packages/jax_md/energy.py", line 28, in <module>
    import haiku as hk
  File "/home/guillermo/.venv/jax/lib/python3.8/site-packages/haiku/__init__.py", line 19, in <module>
    from haiku import data_structures
  File "/home/guillermo/.venv/jax/lib/python3.8/site-packages/haiku/data_structures.py", line 18, in <module>
    from haiku._src.data_structures import to_haiku_dict
  File "/home/guillermo/.venv/jax/lib/python3.8/site-packages/haiku/_src/data_structures.py", line 30, in <module>
    from haiku._src import utils
  File "/home/guillermo/.venv/jax/lib/python3.8/site-packages/haiku/_src/utils.py", line 42, in <module>
    def auto_repr(cls: type[Any], *args, **kwargs) -> str:
TypeError: 'type' object is not subscriptable

I am working on an Ubuntu 20.04.6 LTS x86_64 machine with an NVIDIA GeForce RTX 3060 and Python 3.8.10, and I tried both the GPU and CPU versions of JAX, but the error remains.

Any clue as to what might be going on? Perhaps some version compatibility issue?

Thanks!

arnauqb commented 4 months ago

You need to upgrade to Python >= 3.9. Python 3.8 does not support this kind of type annotation: subscripting a built-in type, as in type[Any], was only added in Python 3.9 (PEP 585).
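The failing line in haiku/_src/utils.py annotates a parameter as type[Any], which subscripts the built-in type; on Python 3.8 that raises exactly the TypeError in the traceback. A minimal sketch of a guard a demo script could run before importing jax_md (and therefore haiku), so the problem surfaces as a readable message rather than a deep traceback. The helper names require_python and builtin_generics_supported, and the (3, 9) default, are hypothetical and not part of the repository.

```python
import sys
from typing import Any, Tuple


def builtin_generics_supported() -> bool:
    """Return True if built-in types can be subscripted, e.g. type[Any] (PEP 585, Python 3.9+)."""
    try:
        type[Any]  # on Python 3.8: TypeError: 'type' object is not subscriptable
    except TypeError:
        return False
    return True


def require_python(minimum: Tuple[int, int] = (3, 9)) -> None:
    """Hypothetical guard: fail early with a clear message on an interpreter that is too old."""
    if sys.version_info[:2] < minimum:
        found = ".".join(map(str, sys.version_info[:3]))
        wanted = ".".join(map(str, minimum))
        raise RuntimeError(
            f"Python {wanted}+ is required (found {found}); "
            "dependencies such as dm-haiku use PEP 585 annotations like type[Any]."
        )


if __name__ == "__main__":
    # Call this before any `import jax_md` / `import haiku`.
    require_python()
    print("built-in generics supported:", builtin_generics_supported())
```

Placing the check before the heavy imports matters: the annotation is evaluated when haiku's module is imported, so the guard has to run first to replace the traceback with its own message.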

conorheins commented 3 months ago

Thanks @arnauqb for stepping in and helping. And apologies @glegarda for not being more specific about Python and JAX versions. I will go back and annotate the versions of each requirement more rigorously once I get a chance.
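One way to record the exact versions behind a working setup is to query the installed distributions with the standard library's importlib.metadata. A minimal sketch; the package list is an assumption, taken from the modules in the traceback (jax, jax_md, haiku) plus jaxlib and flax, using their PyPI distribution names (jax-md, dm-haiku). The helper names are hypothetical.

```python
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Assumed package list: distributions involved in the import chain from the traceback.
DEFAULT_PACKAGES = ("jax", "jaxlib", "jax-md", "dm-haiku", "flax")


def installed_versions(packages: Iterable[str] = DEFAULT_PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is not installed."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def as_requirements(versions: Dict[str, Optional[str]]) -> str:
    """Render installed packages as pinned requirements.txt lines; missing ones are skipped."""
    return "\n".join(f"{name}=={ver}" for name, ver in versions.items() if ver is not None)


if __name__ == "__main__":
    print(f"# python {sys.version.split()[0]}")
    print(as_requirements(installed_versions()))
```

Running this inside a known-good environment gives pinned lines that can be pasted into a requirements file, along with the Python version as a comment.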

glegarda commented 3 months ago

Thank you both for your help! It took some tinkering, but I eventually got the example working. First, I upgraded to Python 3.11. This got rid of the original error but triggered others due to further compatibility issues. In case it helps you annotate the required versions, @conorheins, these are the packages I had to install manually or reinstall:

With these modifications, I was able to run the example.

conorheins commented 3 months ago

Thanks a lot for documenting this so thoroughly, @glegarda. Good to know about the deprecation of the KeyArray attribute in newer versions of JAX. Because JAX is still in experimental development, deprecations and breaks in backward compatibility like this unfortunately crop up frustratingly often. So I'll either (A) pin jax to a release before 0.4.24 (such as 0.4.19) that is still new enough to be compatible with the remaining packages (jax-md, flax 0.8.3, etc.), or (B) update the code to be consistent with the latest versions of jax (0.4.24 and later).
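For option (A), a script could check the installed jax release against the 0.4.24 boundary mentioned above before running. A minimal stdlib-only sketch; the helper names are hypothetical, and the parser only handles plain X.Y.Z release strings (dev or rc suffixes are truncated, not compared).

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

# Per this thread, code relying on jax.random.KeyArray breaks from jax 0.4.24 onwards.
KEYARRAY_REMOVED_IN = (0, 4, 24)


def parse_release(ver: str) -> Tuple[int, ...]:
    """Turn '0.4.19' (or '0.4.24.dev1') into (0, 4, 19); non-numeric tails are dropped."""
    parts = []
    for piece in ver.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def supports_keyarray(jax_version: str) -> bool:
    """True if this jax release predates the KeyArray removal boundary."""
    return parse_release(jax_version) < KEYARRAY_REMOVED_IN


def installed_jax_version() -> Optional[str]:
    """Installed jax version string, or None if jax is not installed."""
    try:
        return version("jax")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_jax_version()
    if v is None:
        print("jax is not installed")
    elif supports_keyarray(v):
        print(f"jax {v}: code using jax.random.KeyArray should still import")
    else:
        print(f"jax {v}: pin jax<0.4.24 or migrate the KeyArray annotations")
```

For option (B), JAX's deprecation notice for jax.random.KeyArray points to jax.Array for type annotations, so a signature such as def step(key: jax.random.KeyArray, ...) would become def step(key: jax.Array, ...).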