FLAIROx / JaxMARL

Multi-Agent Reinforcement Learning with JAX
Apache License 2.0

Jax version #83

Closed: alexunderch closed this issue 2 months ago

alexunderch commented 4 months ago

Hey! Could you please loosen the jax version pinned by the library? The pin makes CUDA unavailable and conflicts with newer versions of the library. For a minimal reproduction, try installing jaxmarl in Colab:

!pip install jaxmarl 

and, then:

import jax
print(jax.devices())

I'd set it to something like jax>=0.4.17 to keep backward compatibility.
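A slightly fuller version of that check can help when reporting this kind of problem. The following is only a minimal sketch: the helper names `installed_version` and `jax_report` are made up for illustration and are not part of jaxmarl. It reads the installed package versions with the standard-library `importlib.metadata` and then tries the same `jax.devices()` call as in the reproduction, without crashing if jax is missing or broken.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def jax_report():
    """Collect versions of the packages discussed here plus visible devices.

    Hypothetical helper for illustration; not part of jaxmarl or jax.
    """
    report = {
        name: installed_version(name)
        for name in ("jax", "jaxlib", "jaxmarl", "scipy")
    }
    try:
        import jax  # imported lazily so the report works without jax

        report["devices"] = [str(device) for device in jax.devices()]
    except Exception as exc:  # jax missing, or backend failed to initialise
        report["devices"] = f"unavailable: {exc!r}"
    return report


if __name__ == "__main__":
    for key, value in jax_report().items():
        print(f"{key}: {value}")
```

If the device list contains only CPU devices on a GPU machine, comparing the reported jax and jaxlib versions with what was there before installing jaxmarl is a reasonable first step.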

amacrutherford commented 4 months ago

Hey! It shouldn't make CUDA unavailable: we assume jax is already installed, but I'll make this clearer in the README. We could change this, but it currently works well with the Dockerfile. We do agree on relaxing the requirement, though, and we'll do it at the end of May once we're done with a submission; for now we want to make sure all our results are collected in the same environment 😄

rubimat commented 4 months ago

I had a similar issue when running the baseline algorithm ippo_rnn_smax.py. I got the errors:

AttributeError: module 'scipy.linalg' has no attribute 'tril'
AttributeError: module 'scipy.linalg' has no attribute 'triu'

These seem to come from an outdated version of jax. The latest release fixes this, but for me it broke other things as well.

alexunderch commented 4 months ago

This error has less to do with the jax version than with scipy. Check that your scipy is <1.11.4, at least for jax<=0.4.25. I think the devs are migrating between APIs right now.
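One way to see whether the installed SciPy still exposes the two functions named in the traceback is to probe for them directly. This is a rough sketch: `missing_attributes` is a hypothetical helper written for this example, and it only reports what the current environment provides. It does not say which jax and scipy versions are compatible.

```python
import importlib


def missing_attributes(module_name, names):
    """Return the entries of `names` that the module does not define.

    Returns None when the module itself cannot be imported.
    Hypothetical helper for illustration only.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return [name for name in names if not hasattr(module, name)]


if __name__ == "__main__":
    # The two attributes named in the AttributeError messages above.
    missing = missing_attributes("scipy.linalg", ["tril", "triu"])
    if missing is None:
        print("scipy is not importable in this environment")
    elif missing:
        print("scipy.linalg lacks:", ", ".join(missing))
    else:
        print("scipy.linalg still provides tril and triu")
```

If the names are reported as missing, the installed scipy is newer than the installed jax expects, so either scipy has to go down or jax has to go up.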

Caffa commented 3 months ago

Here are the current 'cpu' install instructions that I found worked:

conda deactivate
conda remove -n JaxMARL --all 

# to check that you're running on an x86_64 (64-bit) Linux system:
uname -m
# should see x86_64

nvidia-smi 
# should see cuda version here
# driver version must be >= 525.60.13 for CUDA 12 on Linux.

# uses python 3.10
conda create -n JaxMARL python=3.10

conda activate JaxMARL

# quoting keeps shells such as zsh from treating the brackets as a glob pattern
pip install "jax[cpu]==0.4.17" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install jaxlib==0.4.17 -f https://storage.googleapis.com/jax-releases/jax_releases.html

git clone https://github.com/FLAIROx/JaxMARL.git && cd JaxMARL
pip install -e .
export PYTHONPATH=./JaxMARL:$PYTHONPATH

conda list jax
# packages in environment at /home/marmot/micromamba/envs/JaxMARL:
#
# Name                    Version                   Build  Channel
# jax                     0.4.17                   pypi_0    pypi
# jaxlib                  0.4.17                   pypi_0    pypi
# jaxmarl                 0.0.3                    pypi_0    pypi

I used the CPU build because there is no pip install jax[cuda12]==0.4.17 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html, right?

alexunderch commented 3 months ago

Hey! Let me go through your question point by point.
1) You're right, there is no jax 0.4.17 build for cuda12: you can check by opening https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
2) As stated above, jaxmarl assumes you already have the jax version of your choice installed, so you can install jax first and then proceed with the rest of the installation. However, you might not be happy with the result, because for the time being you are stuck with jax==0.4.17.
3) So you can install jaxmarl from source (if you are willing to contribute) or with pip install jaxmarl (if you just want the stable version), and then install the jax version of your choice, for example:

pip install "jax[cuda12]==0.4.25" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Or, for CPU:

pip install jax==0.4.25 

Clear? Keep in mind, though, that if you want to work with cuda11 the process might be a bit trickier. Tag me if I need to elaborate.

alexunderch commented 3 months ago

I think we can add some notes to the 'installation' section and relax the jax requirement, to avoid future confusion.

amacrutherford commented 2 months ago

@alexunderch see #97

alexunderch commented 2 months ago

@amacrutherford, I took a look: well-placed changes, plus the user specification in the Dockerfile. The only thing to worry about is that jax releases after 0.4.25 could change the scipy-facing or general jax API and cause issues in the future. So maybe set an upper bound at the newest version you currently work with, or something like that?
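To illustrate what a bounded requirement would accept, here is a small sketch of a lower-and-upper-bound check. The bounds 0.4.17 and 0.4.25 are simply the versions mentioned in this thread, not the project's actual pins, and `version_tuple` and `within_bounds` are hypothetical helpers. The comparison looks only at the numeric release part, so it is a simplification of full PEP 440 handling.

```python
import re


def version_tuple(version, width=3):
    """Numeric release part as a tuple, e.g. '0.4.25.dev1' -> (0, 4, 25).

    Pre-release and local tags are ignored, and short versions are padded
    with zeros so that '0.4' compares equal to '0.4.0'.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"no numeric release in {version!r}")
    parts = [int(piece) for piece in match.group().split(".")]
    parts += [0] * (width - len(parts))
    return tuple(parts)


def within_bounds(version, lower, upper):
    """True when lower <= version <= upper, comparing release numbers only."""
    return version_tuple(lower) <= version_tuple(version) <= version_tuple(upper)


if __name__ == "__main__":
    # Assumed bounds, taken from the versions discussed in this thread.
    for candidate in ("0.4.16", "0.4.17", "0.4.25", "0.4.26"):
        print(candidate, within_bounds(candidate, "0.4.17", "0.4.25"))
```

In a requirements file the same idea would be written as a version range on the jax line; the sketch only shows which versions such a range lets through.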

All good, thank you very much for this! Can close!

amacrutherford commented 2 months ago

yea good point, will do :)