UsaidPro opened this issue 6 days ago (status: open)
Testing the MaBrax environment with python mava/systems/sac/ff_isac.py, I get a different error:
/home/usaidm/Documents/ExternalLibs/mava/mava/systems/sac/ff_isac.py:106: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
obs_single_batched = jax.tree_map(lambda x: x[0][jnp.newaxis, ...], obs)
Error executing job with overrides: []
Error in call to target 'mava.networks.DiscreteActionHead':
TypeError("__init__() got an unexpected keyword argument 'independent_std'")
full_key: network.action_head
Maybe the failure is due to a Python version conflict? But the README.md says Mava is tested on Python 3.9, and I am using 3.9.10.
Hi @UsaidPro, thank you for raising this issue. We will investigate it further over the coming days. For now, what solved the problem for me was downgrading Flax and installing jax[cuda12] with jax==0.4.26 instead of jax[cuda12_local]. In your virtualenv, could you please run:
pip install flax==0.8.1
pip install "jax[cuda12]==0.4.26"
And then test with python mava/systems/ppo/ff_ippo.py env=rware env/scenario=tiny-2ag
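Before re-running, it may help to confirm that the downgrade actually took effect in the active virtualenv. Below is a minimal sketch using only the standard library's importlib.metadata; the pin values are the ones suggested above, while the names PINS, installed_version and find_mismatches are hypothetical helpers, not part of Mava:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, Optional, Tuple

# Versions suggested in this thread (hypothetical helper, not part of Mava).
PINS: Dict[str, str] = {"flax": "0.8.1", "jax": "0.4.26"}


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def find_mismatches(
    pins: Dict[str, str],
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map each package whose installed version differs from its pin to (wanted, found)."""
    result: Dict[str, Tuple[str, Optional[str]]] = {}
    for package, wanted in pins.items():
        found = lookup(package)
        if found != wanted:
            result[package] = (wanted, found)
    return result


if __name__ == "__main__":
    mismatches = find_mismatches(PINS)
    if not mismatches:
        print("All pinned versions are installed.")
    for package, (wanted, found) in mismatches.items():
        print(f"{package}: expected {wanted}, found {found}")
```

The lookup parameter is injectable so the comparison can be checked without Flax or JAX installed.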
As for the second problem you are facing: MaBrax is a continuous action space environment, but the network config uses a discrete action head. To fix this, either edit action_head._target_ in the network config at mava/configs/network/mlp.yaml to be:
action_head:
  _target_: mava.networks.ContinuousActionHead # [DiscreteActionHead, ContinuousActionHead]
Or override it from the terminal by running: python mava/systems/sac/ff_isac.py network.action_head._target_=mava.networks.ContinuousActionHead
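The TypeError in the log appears to come from Hydra-style instantiation: the class named by _target_ is called with the remaining config keys as keyword arguments, and the head config evidently carries independent_std, which the discrete head does not accept. The toy sketch below reproduces that mechanism with the standard library only; the two classes and the REGISTRY/instantiate helpers are hypothetical stand-ins, not Mava's or Hydra's actual code:

```python
from typing import Any, Callable, Dict


class DiscreteActionHead:
    """Toy stand-in for a discrete head: it only takes action_dim."""

    def __init__(self, action_dim: int) -> None:
        self.action_dim = action_dim


class ContinuousActionHead:
    """Toy stand-in for a continuous head: it also takes independent_std."""

    def __init__(self, action_dim: int, independent_std: bool = True) -> None:
        self.action_dim = action_dim
        self.independent_std = independent_std


# Hypothetical registry standing in for importing the dotted path in `_target_`.
REGISTRY: Dict[str, Callable[..., Any]] = {
    "mava.networks.DiscreteActionHead": DiscreteActionHead,
    "mava.networks.ContinuousActionHead": ContinuousActionHead,
}


def instantiate(config: Dict[str, Any], **extra: Any) -> Any:
    """Call the `_target_` with all other config keys (plus `extra`) as kwargs."""
    kwargs = {key: value for key, value in config.items() if key != "_target_"}
    kwargs.update(extra)
    return REGISTRY[config["_target_"]](**kwargs)


if __name__ == "__main__":
    # A head config that carries independent_std, as the error message implies.
    head_cfg = {"_target_": "mava.networks.DiscreteActionHead", "independent_std": False}
    try:
        instantiate(head_cfg, action_dim=2)
    except TypeError as err:
        print("Discrete head rejected the config:", err)

    # Same config with the target switched, mirroring the fix in this thread.
    head_cfg["_target_"] = "mava.networks.ContinuousActionHead"
    head = instantiate(head_cfg, action_dim=2)
    print(type(head).__name__, head.independent_std)
```

Changing only _target_ leaves the rest of the config untouched, which is why a one-line YAML edit or command-line override is enough.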
Please let me know how this goes.
Hi @UsaidPro, I've only ever seen this error when using Python 3.11 or 3.12, so it's interesting that you're seeing it on 3.9. As Ruan mentioned, could you check whether not using the local CUDA install works for you?
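Since this error has so far only been seen on Python 3.11 and 3.12, it is worth confirming which interpreter the virtualenv actually runs. A minimal standard-library sketch; the TESTED tuple reflects the README claim quoted in this thread, and the helper names are hypothetical:

```python
import sys
from typing import Optional, Sequence

# Minor version the README says Mava is tested on (as quoted in this thread).
TESTED = (3, 9)


def is_tested_python(version: Optional[Sequence[int]] = None) -> bool:
    """Return True if (major, minor) matches TESTED; defaults to the running interpreter."""
    major_minor = tuple(version[:2]) if version is not None else tuple(sys.version_info[:2])
    return major_minor == TESTED


def in_virtualenv() -> bool:
    """Heuristic: venv (and virtualenv >= 20) make sys.prefix differ from sys.base_prefix."""
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print("Python", sys.version.split()[0], "at", sys.executable)
    print("Tested version:", is_tested_python(), "| inside virtualenv:", in_virtualenv())
```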
Describe the bug
Hello! I am trying to get Mava working to test out the library. Following the README, I created a Python 3.9.10 virtualenv, installed jax[cuda12_local] (I have an existing CUDA 12.3 install), and ran pip install -e . in the mava/ GitHub root directory. When I try to run python mava/systems/ppo/ff_ippo.py, I get the following failure:
To Reproduce
Steps to reproduce the behavior:
1. pip install "jax[cuda12_local]"
2. Follow DETAILED_INSTALL.md
3. Run python mava/systems/ppo/ff_ippo.py in the mava/ GitHub root directory
Expected behavior
Successfully start Mava and train on the RobotWarehouse env.
Context (Environment)
pip freeze > requirements.txt - GitHub gist link here
Additional context
N/A
Possible Solution
Not sure, since I am new to the package. I will test other environments to see whether they have the same issue.