FLAIROx / JaxMARL

Multi-Agent Reinforcement Learning with JAX
Apache License 2.0

Dimensions of `world_state` and `reward` do not match #85

Closed victor-qin closed 4 months ago

victor-qin commented 4 months ago

https://github.com/FLAIROx/JaxMARL/blob/d941770bb1bf6945412c70c59c509226d9d39628/baselines/MAPPO/mappo_rnn_mpe.py#L292C1-L295C33

Each agent receives (more or less) the same reward at each step of the environment. In the output of `_env_state`, the batched rewards are ordered agent-major: [(agent1, env1), (agent1, env2), ..., (agent2, env1), (agent2, env2), ...], but the batched `world_state`s are ordered env-major: [(agent1, env1), (agent2, env1), (agent3, env1), (agent1, env2), ...] (where (agent1, env1) = (agent2, env1) = (agent3, env1), since the world state is shared across agents).

This breaks the advantage calculation later: `last_val` inherits the ordering of `world_state`, so each reward is matched against the wrong value estimate.

To reproduce:

Solution: it should just be a matter of changing the `order` option in the `jnp.reshape` call so both tensors are flattened the same way.
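A minimal sketch of the mismatch, using hypothetical sizes and plain NumPy (whose `reshape` `order` semantics `jnp.reshape` follows): flattening a `(num_agents, num_envs)` array with `order="C"` gives the agent-major layout of the rewards, while `order="F"` gives the env-major layout the world states end up in.

```python
import numpy as np

num_agents, num_envs = 3, 2  # hypothetical sizes, for illustration only

# Encode each slot as agent_index * 10 + env_index so the orderings are visible.
grid = np.arange(num_agents)[:, None] * 10 + np.arange(num_envs)[None, :]

# Rewards are batchified agent-major (row-major, order="C"):
# (agent0, env0), (agent0, env1), (agent1, env0), ...
agent_major = grid.reshape(-1)            # [ 0,  1, 10, 11, 20, 21]

# world_state ends up env-major (column-major, order="F"):
# (agent0, env0), (agent1, env0), (agent2, env0), (agent0, env1), ...
env_major = grid.reshape(-1, order="F")   # [ 0, 10, 20,  1, 11, 21]

# Index i now refers to a different (agent, env) pair in each vector:
assert not np.array_equal(agent_major, env_major)

# Flattening both with the same order option restores the alignment:
assert np.array_equal(grid.reshape(-1, order="C"), agent_major)
```

With the sizes above, position 1 holds (agent0, env1)'s reward but (agent1, env0)'s value, which is exactly the misalignment that corrupts the advantage estimates.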

amacrutherford commented 4 months ago

cheers for spotting this! Fixed with #87