instadeepai / Mava

🦁 A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX
Apache License 2.0
709 stars, 83 forks

Feat/remove agent vmapping #1038

Closed: RuanJohn closed this 7 months ago

RuanJohn commented 7 months ago

What?

Remove manual vmapping over agents and instead rely on Flax's automatic batching over leading dimensions, since parameters are shared across agents.

Why?

Our policies return distributions that can be sampled from, and we ran into issues with both Distrax and TensorFlow Probability: the creation of distributions for continuous action space policies cannot be vmapped. To keep things consistent across systems, we opted to remove vmapping over agents.
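As a rough illustration of the change described under "What?", the sketch below compares manually vmapping a shared-parameter policy over the agent axis with calling it once on the stacked observations. It is not code from this PR: `init_params`, `policy_logits`, and the shapes are hypothetical, and a hand-written dense layer stands in for a Flax module such as `nn.Dense`, which broadcasts over leading dimensions in the same way.

```python
import jax
import jax.numpy as jnp


def init_params(key, obs_dim, num_actions):
    """Create one set of dense-layer parameters shared by all agents (hypothetical helper)."""
    return {
        "w": jax.random.normal(key, (obs_dim, num_actions)) * 0.1,
        "b": jnp.zeros((num_actions,)),
    }


def policy_logits(params, obs):
    """Dense layer acting on the last axis; any leading axes are treated as batch axes."""
    return obs @ params["w"] + params["b"]


# Illustrative sizes only.
num_agents, obs_dim, num_actions = 3, 4, 5
param_key, obs_key = jax.random.split(jax.random.PRNGKey(0))
params = init_params(param_key, obs_dim, num_actions)
obs = jax.random.normal(obs_key, (num_agents, obs_dim))

# Before: map explicitly over the agent axis while sharing the parameters.
vmapped_logits = jax.vmap(policy_logits, in_axes=(None, 0))(params, obs)

# After: a single call, with the agent axis handled as an ordinary batch axis.
batched_logits = policy_logits(params, obs)
```

Because the parameters are shared, both calls produce the same `(num_agents, num_actions)` array. The batched form also means that any distribution built from the network output is constructed once, outside of `vmap`.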