FLAIROx / JaxMARL

Multi-Agent Reinforcement Learning with JAX
Apache License 2.0

Fixing incomplete state update issue in the step function of STORM environment #60

Closed hy-kiera closed 5 months ago

hy-kiera commented 6 months ago

The _step() function was returning a state in which some fields were not updated, which made the calculation of done inaccurate. Specifically, while agent_freezes, agent_positions, and other fields were updated, inner_t remained unchanged in the returned state.

To fix this, I've replaced the returned state with state_nxt, in which all fields are updated, so that done is calculated correctly.
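For readers following along, here is a minimal sketch of the kind of bug described above. It is not the STORM code: the State class, its two fields, step_buggy, step_fixed, and num_inner_steps are hypothetical stand-ins, and a frozen standard-library dataclass is used in place of the environment's real state object. It only illustrates how returning a state with a stale inner_t keeps done from ever firing.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class State:
    # Hypothetical, heavily reduced stand-in for the environment state.
    agent_position: int
    inner_t: int


def step_buggy(state: State, num_inner_steps: int):
    # All updates are computed into state_nxt ...
    state_nxt = replace(
        state,
        agent_position=state.agent_position + 1,
        inner_t=state.inner_t + 1,
    )
    done = state_nxt.inner_t >= num_inner_steps
    # ... but the returned state only carries over some fields,
    # so inner_t is still the old value on the next call.
    stale = replace(state, agent_position=state_nxt.agent_position)
    return stale, done


def step_fixed(state: State, num_inner_steps: int):
    state_nxt = replace(
        state,
        agent_position=state.agent_position + 1,
        inner_t=state.inner_t + 1,
    )
    done = state_nxt.inner_t >= num_inner_steps
    # Return the fully updated state so the counter advances.
    return state_nxt, done


def rollout(step_fn, steps: int, num_inner_steps: int):
    # Apply step_fn repeatedly, threading the returned state through.
    state, done = State(agent_position=0, inner_t=0), False
    for _ in range(steps):
        state, done = step_fn(state, num_inner_steps)
    return state, done
```

Under these assumptions, rollout(step_buggy, 3, 3) leaves inner_t at 0 with done still False, while rollout(step_fixed, 3, 3) reaches inner_t == 3 with done True.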

Please review, and feel free to provide any additional feedback or suggestions for further refinement.

hy-kiera commented 6 months ago

Hi @amacrutherford @Aidandos, could you please review my PR? Thanks.

amacrutherford commented 6 months ago

Hey! Apologies for the delay on this. @Aidandos had a paper deadline but should get to it over the next few days :smile:

Aidandos commented 5 months ago

Hi @hy-kiera. Apologies for the delayed response, and thanks for pointing out this issue. Your change fixes the inner_t counting, but not the outer_t counting. I edited your PR to fix both.

The counting was broken both in storm_2p.py and storm_env.py.

I added some print statements to the tutorials, showing that inner_t and outer_t are incrementing correctly. They also print the done logic. To run the tutorials, simply run python3 jaxmarl/tutorials/storm_2p_introduction.py or python3 jaxmarl/tutorials/storm_introduction.py.
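As a rough illustration of what "incrementing correctly" could look like for a pair of nested counters, the sketch below assumes a rollover scheme that this thread does not spell out: inner_t resets to 0 when it reaches a limit, outer_t increments at that moment, and done fires once outer_t reaches its own limit. The Counters class, advance, trace, num_inner_steps, and num_outer_steps are all hypothetical names for this example and are not taken from storm_env.py or storm_2p.py.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Counters:
    # Hypothetical pair of counters, mirroring the inner_t / outer_t names.
    inner_t: int = 0
    outer_t: int = 0


def advance(c: Counters, num_inner_steps: int, num_outer_steps: int):
    # Assumed scheme: inner_t counts steps inside one inner episode.
    inner_t = c.inner_t + 1
    reset_inner = inner_t == num_inner_steps
    # Assumed scheme: outer_t counts completed inner episodes.
    outer_t = c.outer_t + 1 if reset_inner else c.outer_t
    inner_t = 0 if reset_inner else inner_t
    # Assumed scheme: done once the last inner episode has finished.
    done = reset_inner and outer_t == num_outer_steps
    return Counters(inner_t=inner_t, outer_t=outer_t), done


def trace(steps: int, num_inner_steps: int, num_outer_steps: int):
    # Collect (inner_t, outer_t, done) after each step for inspection.
    c, out = Counters(), []
    for _ in range(steps):
        c, done = advance(c, num_inner_steps, num_outer_steps)
        out.append((c.inner_t, c.outer_t, done))
    return out


if __name__ == "__main__":
    for row in trace(4, num_inner_steps=2, num_outer_steps=2):
        print(row)
```

If either counter were dropped from the returned state, the corresponding column in this trace would stay at 0, which is the kind of symptom the tutorial print statements are meant to make visible.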