instadeepai / Mava

🦁 A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX
Apache License 2.0

[BUG] LLVM ERROR: mma16816 data type not supported #1082

Open UsaidPro opened 6 days ago

UsaidPro commented 6 days ago

Describe the bug

Hello! I am trying to get Mava working to test out the library. Following the README, I created a Python 3.9.10 virtualenv, installed jax[cuda12_local] (I have an existing CUDA 12.3 install), and ran pip install -e . in the mava/ GitHub root directory.

When I try to run python mava/systems/ppo/ff_ippo.py, I get the following failure:

/home/usaidm/Documents/ExternalLibs/mava/mava/systems/ppo/ff_ippo.py:410: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  env_states = jax.tree_map(reshape_states, env_states)
/home/usaidm/Documents/ExternalLibs/mava/mava/systems/ppo/ff_ippo.py:411: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  timesteps = jax.tree_map(reshape_states, timesteps)
/home/usaidm/Documents/ExternalLibs/mava/mava/systems/ppo/ff_ippo.py:431: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  replicate_learner = jax.tree_map(broadcast, replicate_learner)
[... (default config is printed here)]

LLVM ERROR: mma16816 data type not supported
Aborted (core dumped)

To Reproduce

Steps to reproduce the behavior:

  1. pip install "jax[cuda12_local]"
  2. Follow the Python virtual environment-based installation guide in DETAILED_INSTALL.md
  3. Run python mava/systems/ppo/ff_ippo.py in the mava/ GitHub root directory

Expected behavior

Successfully start Mava and train on RobotWarehouse env.

Context (Environment)

Additional context

N/A

Possible Solution

Not sure, since I am new to the package. I will test other environments to see if they have the same issue.

UsaidPro commented 5 days ago

Testing the MaBrax environment with python mava/systems/sac/ff_isac.py, I get this different error:

/home/usaidm/Documents/ExternalLibs/mava/mava/systems/sac/ff_isac.py:106: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  obs_single_batched = jax.tree_map(lambda x: x[0][jnp.newaxis, ...], obs)
Error executing job with overrides: []
Error in call to target 'mava.networks.DiscreteActionHead':
TypeError("__init__() got an unexpected keyword argument 'independent_std'")
full_key: network.action_head

Maybe the failure is due to a Python version conflict? The README.md says Mava is tested on Python 3.9, though, and I am using 3.9.10.

RuanJohn commented 5 days ago

Hi @UsaidPro, thank you for raising this issue. We will investigate this further over the coming days. For now, what solved the problem for me was downgrading Flax and using jax[cuda12] with jax==0.4.26 instead of jax[cuda12_local]. In your virtualenv, could you please run:

pip install flax==0.8.1
pip install "jax[cuda12]==0.4.26"

And then test with python mava/systems/ppo/ff_ippo.py env=rware env/scenario=tiny-2ag
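
Before re-running, it may help to confirm which versions actually ended up in the virtualenv. The following is a minimal sketch using only the standard library; the helper names installed_version and report are hypothetical, and the version table simply restates the pins suggested above rather than any officially supported matrix.

```python
from importlib import metadata

# Versions suggested in this thread as a working combination (not an official matrix).
SUGGESTED = {"jax": "0.4.26", "flax": "0.8.1"}


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def report(suggested=SUGGESTED):
    """Map each package name to (installed, suggested, matches)."""
    result = {}
    for name, wanted in suggested.items():
        found = installed_version(name)
        result[name] = (found, wanted, found == wanted)
    return result


if __name__ == "__main__":
    for name, (found, wanted, ok) in report().items():
        status = "ok" if ok else "differs"
        print(f"{name}: installed={found} suggested={wanted} ({status})")
```

Running this inside the virtualenv prints one line per package, which makes it easy to spot a pin that did not take effect.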

The second problem you are facing occurs because MaBrax is a continuous action space environment, while the network config is using a discrete action head. To fix this, either edit action_head._target_ in the network config in mava/configs/network/mlp.yaml to be

action_head:
  _target_: mava.networks.ContinuousActionHead # [DiscreteActionHead, ContinuousActionHead]

Or override it from the terminal by running python mava/systems/sac/ff_isac.py network.action_head._target_=mava.networks.ContinuousActionHead
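
To illustrate why the mismatch surfaces as a TypeError, here is a minimal sketch of the mechanism, assuming the config keys under action_head are forwarded to the target class as keyword arguments (which is what the traceback suggests). The two classes and the build_head helper are toy stand-ins with hypothetical signatures, not Mava's real implementations.

```python
# Toy stand-ins for the two action heads; these are NOT Mava's real classes.
class DiscreteActionHead:
    def __init__(self, action_dim):
        self.action_dim = action_dim


class ContinuousActionHead:
    def __init__(self, action_dim, independent_std=True):
        self.action_dim = action_dim
        self.independent_std = independent_std


HEADS = {"discrete": DiscreteActionHead, "continuous": ContinuousActionHead}


def build_head(kind, **kwargs):
    """Instantiate the head for `kind`, forwarding config keys as keyword arguments."""
    return HEADS[kind](**kwargs)


# Config keys intended for a continuous head (hypothetical values).
config = {"action_dim": 4, "independent_std": True}

# Matching target: the constructor accepts every key.
head = build_head("continuous", **config)

# Mismatched target: the discrete constructor rejects the extra key.
try:
    build_head("discrete", **config)
except TypeError as err:
    print(f"TypeError: {err}")
```

Pointing the target at the continuous head makes the extra key valid again, which is the effect of either fix above.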

Please let me know how this goes.

sash-a commented 3 days ago

Hi @UsaidPro, I've only ever seen this error when trying to use Python 3.11 and 3.12, so it's interesting that you're seeing it on 3.9. As Ruan mentioned, could you check whether it works without the local CUDA install?