FLAIROx / JaxMARL

Multi-Agent Reinforcement Learning with JAX
Apache License 2.0

IPPO inference + save animation #72

Closed satpreetsingh closed 5 months ago

satpreetsingh commented 5 months ago

Added an example of how to do inference with a trained IPPO agent

Code then saves the animation for 2 episodes

Addresses request made in https://github.com/FLAIROx/JaxMARL/issues/64
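Roughly, the added example does something along these lines (a minimal sketch, not the exact PR code; the env name, the placeholder random policy, and the visualizer arguments are assumptions and should be swapped for the trained IPPO network and config used in ippo_rnn_mpe.py):

```python
import jax
from jaxmarl import make
from jaxmarl.viz.visualizer import MPEVisualizer

env = make("MPE_simple_spread_v3")
key = jax.random.PRNGKey(0)

for ep in range(2):  # two episodes, as in the PR
    key, key_reset = jax.random.split(key)
    obs, state = env.reset(key_reset)
    state_seq = [state]
    for _ in range(25):  # default MPE episode length
        key, key_act, key_step = jax.random.split(key, 3)
        # Placeholder policy: swap in the trained IPPO network's apply() on `obs`.
        actions = {a: env.action_space(a).sample(key_act) for a in env.agents}
        obs, state, rewards, dones, infos = env.step(key_step, state, actions)
        state_seq.append(state)

    # Render the collected states and write one GIF per episode.
    viz = MPEVisualizer(env, state_seq)
    viz.animate(view=False, save_fname=f"ippo_mpe_ep{ep:02d}.gif")
```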

mttga commented 5 months ago

Hi @satpreetsingh, the code you're providing doesn't follow JAX best practices. You can get multiple trajectories much more efficiently using jax.lax.scan and vmap (instead of nested for loops), and then post-process the collected states for visualization. I think I have a piece of code somewhere to do this.
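Something along these lines (a rough sketch of the scan/vmap pattern, not the code I'm referring to; the env name, the placeholder random policy, and the fixed horizon are assumptions):

```python
import jax
from jaxmarl import make

env = make("MPE_simple_spread_v3")

def rollout(key, num_steps=25):
    key, key_reset = jax.random.split(key)
    obs, state = env.reset(key_reset)

    def step_fn(carry, _):
        key, obs, state = carry
        key, key_act, key_step = jax.random.split(key, 3)
        # Placeholder random policy; replace with the trained network's apply().
        actions = {a: env.action_space(a).sample(key_act) for a in env.agents}
        obs, state, rewards, dones, infos = env.step(key_step, state, actions)
        return (key, obs, state), state  # stack the env state at every step

    # Scan the single-step function over time instead of a Python loop.
    _, state_seq = jax.lax.scan(step_fn, (key, obs, state), None, num_steps)
    return state_seq

# vmap over RNG keys -> a pytree of stacked states with a leading
# (num_episodes, num_steps) batch; slice per episode before visualizing.
state_seqs = jax.vmap(rollout)(jax.random.split(jax.random.PRNGKey(0), 2))
```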

satpreetsingh commented 5 months ago

Hi @mttga, thanks for responding! I was trying to produce the simplest example possible. I'm happy to update my code to follow the recommended style if you can point me to your code.

Also, do the produced animations look good on your end? On my end, they seemed to indicate either that the learning process had not fully converged (at the config-specified 2e6 steps), or that there was some issue in transferring the converged weights from train to test.

amacrutherford commented 5 months ago

Hey @satpreetsingh! Not sure this file is the best place for this, as we would like to keep the training files as lightweight as possible to aid understanding of the core code. Instead, this should likely go in the walkthrough notebook, under the last part, which features Overcooked. I'll put it on our internal TODO to add inference to the notebook and put a note of this in the README :smile:

satpreetsingh commented 5 months ago

Thanks for responding @amacrutherford. Can you also comment on the fact that the post-training (converged) agents don't perform too well in the produced animations (see attached examples)?

The training seems to converge:

$ python ippo_rnn_mpe.py 
{'returns': -102.80365, 'env_step': 0}
{'returns': -100.726265, 'env_step': 3072}
{'returns': -104.75728, 'env_step': 6144}
{'returns': -94.75876, 'env_step': 9216}
{'returns': -89.8966, 'env_step': 12288}
{'returns': -94.21084, 'env_step': 15360}
{'returns': -92.89967, 'env_step': 18432}
{'returns': -93.211845, 'env_step': 21504}
...
{'returns': -67.70681, 'env_step': 1978368}
{'returns': -68.14425, 'env_step': 1981440}
{'returns': -68.150986, 'env_step': 1984512}
{'returns': -66.487785, 'env_step': 1987584}
{'returns': -65.92587, 'env_step': 1990656}
{'returns': -67.4093, 'env_step': 1993728}
{'returns': -66.18083, 'env_step': 1996800}

Saved: ippo_mpe_ep00.gif
Saved: ippo_mpe_ep01.gif

Do you prefer that I open an issue about this?

(attached animations: ippo_mpe_ep00.gif, ippo_mpe_ep01.gif)

Also for reference, here's what a good converged policy should look like: https://www.youtube.com/watch?v=QQ4dauqfmnU

amacrutherford commented 5 months ago

Ah interesting, yeah, could you chuck this in an issue and we'll take a look? Thanks for raising.

satpreetsingh commented 5 months ago

Done! https://github.com/FLAIROx/JaxMARL/issues/73

satpreetsingh commented 5 months ago

@mttga: Can you point me to the (inference) code you were referring to above?