instadeepai / Mava

🦁 A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX
Apache License 2.0

fix: mabrax wrapper #1031

Closed: sash-a closed this 7 months ago

sash-a commented 7 months ago

What?

The MaBrax wrapper had some jitting issues, so it now uses @cached_property to pre-compute these values.
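A minimal sketch of the @cached_property pattern, assuming a hypothetical inner environment that reports per-agent action sizes. The value is computed on first access and then stored on the wrapper instance, so later accesses (for example from code traced under jax.jit) reuse a plain Python int instead of recomputing it. InnerEnv, MaBraxWrapperSketch, get_action_sizes and the "largest action size" rule are illustrative assumptions, not Mava's actual wrapper.

```python
from functools import cached_property
from typing import Dict, Tuple


class InnerEnv:
    """Hypothetical stand-in for a multi-agent Brax environment."""

    def __init__(self) -> None:
        # Per-agent action sizes (illustrative values).
        self.action_sizes: Dict[str, int] = {"agent_0": 3, "agent_1": 2}
        self.spec_calls = 0  # counts how often the wrapper queries sizes

    def get_action_sizes(self) -> Dict[str, int]:
        self.spec_calls += 1
        return self.action_sizes


class MaBraxWrapperSketch:
    """Hypothetical wrapper that caches values derived from the inner env."""

    def __init__(self, env: InnerEnv) -> None:
        self._env = env

    @cached_property
    def action_dim(self) -> int:
        # Computed once on first access, then stored in the instance __dict__.
        # Assumption for illustration: use the largest per-agent action size.
        return max(self._env.get_action_sizes().values())

    @cached_property
    def num_agents(self) -> int:
        # Also computed once and cached.
        return len(self._env.get_action_sizes())

    def action_shape(self) -> Tuple[int, int]:
        # Reads the cached values; no further calls into the inner env.
        return (self.num_agents, self.action_dim)
```

Calling action_shape() repeatedly queries the inner environment only twice in total (once per cached property), which is the point of pre-computing the values.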

Also, in the previous JaxMARL PR I left a misleading comment suggesting we rename states to brax state. Since the wrapped state could come from any inner environment, not just Brax, I've renamed that.

liamclarkza commented 7 months ago

If I run an experiment on the mabrax environment with any of the available systems, I get the error below. The issue seems to be that the system implementations (ff_ippo, ff_mappo, rec ...) expect env.action_spec() to return either a DiscreteArray or a MultiDiscreteArray.

For example, when running python mava/systems/ff_ippo.py env=mabrax:

Error executing job with overrides: ['env=mabrax']
Traceback (most recent call last):
  File "mava/systems/ff_ippo.py", line 599, in hydra_entry_point
    eval_performance = run_experiment(cfg)
  File "mava/systems/ff_ippo.py", line 472, in run_experiment
    learn, actor_network, learner_state = learner_setup(
  File "mava/systems/ff_ippo.py", line 349, in learner_setup
    num_actions = int(env.action_spec().num_values[0])
AttributeError: 'BoundedArray' object has no attribute 'num_values'

From what I can see, we need to either change the BoundedArray class or update the systems to handle BoundedArray action specs for continuous (non-discrete) environments.
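One way to picture the second option is a helper that branches on the spec type. This is a sketch only: DiscreteSpec, BoundedSpec and get_num_actions are hypothetical stand-ins, not Mava or Jumanji classes. The discrete branch mirrors the num_values[0] lookup from the traceback, and the continuous branch assumes an action spec shaped (num_agents, action_dim).

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class DiscreteSpec:
    """Hypothetical stand-in for a (Multi)DiscreteArray action spec."""

    num_values: Tuple[int, ...]  # one entry per agent


@dataclass(frozen=True)
class BoundedSpec:
    """Hypothetical stand-in for a BoundedArray action spec."""

    shape: Tuple[int, ...]  # assumed to be (num_agents, action_dim)
    minimum: float = -1.0
    maximum: float = 1.0


def get_num_actions(spec: Union[DiscreteSpec, BoundedSpec]) -> int:
    """Return the per-agent action count for discrete or continuous specs."""
    if isinstance(spec, DiscreteSpec):
        # Discrete: number of choices, as in num_values[0] in the traceback.
        return int(spec.num_values[0])
    if isinstance(spec, BoundedSpec):
        # Continuous: length of each agent's action vector.
        return int(spec.shape[1])
    raise TypeError(f"Unsupported action spec: {type(spec).__name__}")
```

Knowing the action count is not sufficient on its own: a continuous system also needs a different policy output (for example a Gaussian rather than a categorical distribution), so updating the existing discrete systems is more than a spec lookup.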

WiemKhlifi commented 7 months ago

@liamclarkza This environment supports only continuous actions, and the wrapper is just an early step before we merge the rest of the continuous systems. If you want to test it, please check out this fresh branch (feat/check_global_state); we used it to run a sweep with the ff_mappo_cont system 🙌

Update: on that branch we use num_actions = env.action_spec().shape[1]. For consistency, we may later add an attribute that gives num_actions without accessing the spec.
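The attribute idea could look roughly like the sketch below. It assumes an action spec shaped (num_agents, action_dim), so shape[1] is the per-agent action size. ContinuousEnvSketch, BoundedSpec and the num_actions property are hypothetical names for illustration, not an existing Mava API.

```python
from dataclasses import dataclass
from functools import cached_property
from typing import Tuple


@dataclass(frozen=True)
class BoundedSpec:
    """Hypothetical stand-in for a BoundedArray action spec."""

    shape: Tuple[int, int]  # assumed (num_agents, action_dim)


class ContinuousEnvSketch:
    """Hypothetical env wrapper exposing num_actions as an attribute."""

    def __init__(self, num_agents: int, action_dim: int) -> None:
        self._num_agents = num_agents
        self._action_dim = action_dim

    def action_spec(self) -> BoundedSpec:
        return BoundedSpec(shape=(self._num_agents, self._action_dim))

    @cached_property
    def num_actions(self) -> int:
        # Same value as action_spec().shape[1], computed once so that
        # systems can read it without inspecting the spec themselves.
        return self.action_spec().shape[1]
```

With this, a system would read env.num_actions regardless of how the spec is structured, and only the wrapper needs to know the spec layout.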