instadeepai / Mava

🦁 A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX
Apache License 2.0
692 stars, 83 forks

[FEATURE] GNN support for the MARL frameworks #1091

Open: Jaroan opened this issue 1 month ago

Jaroan commented 1 month ago

Please describe the purpose of the feature. Is it related to a problem?

I am inquiring about integrating JAX-based Graph Neural Networks (GNNs) into Mava for use in MARL. Many MARL methods achieve higher success rates by using GNNs to represent the neighbors of each entity, especially in environments where the relationships between agents are complex and dynamic.
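To make "representing the neighbors of each entity" concrete, here is a minimal sketch in plain JAX (no GNN library). It assumes agents are connected when they are within a fixed `radius` of each other and uses simple mean aggregation over neighbors. `build_adjacency`, `mean_message_passing`, and the weight names are hypothetical, not part of Mava or InforMARL.

```python
import jax
import jax.numpy as jnp


def build_adjacency(positions: jnp.ndarray, radius: float) -> jnp.ndarray:
    """Hypothetical helper: connect agents closer than `radius` to each other.

    positions: (num_agents, 2) agent coordinates.
    Returns a (num_agents, num_agents) 0/1 float mask with a zero diagonal.
    """
    diffs = positions[:, None, :] - positions[None, :, :]
    dists = jnp.linalg.norm(diffs, axis=-1)
    adj = (dists < radius).astype(jnp.float32)
    # Remove self-loops so each agent aggregates only its neighbors.
    return adj * (1.0 - jnp.eye(positions.shape[0]))


def mean_message_passing(features, adj, w_self, w_neigh):
    """One round of mean-aggregation message passing.

    features: (num_agents, in_dim); adj: (num_agents, num_agents).
    """
    degree = adj.sum(axis=-1, keepdims=True)
    # Average the neighbors' features; agents with no neighbors receive zeros.
    neigh_mean = (adj @ features) / jnp.maximum(degree, 1.0)
    return jax.nn.relu(features @ w_self + neigh_mean @ w_neigh)


positions = jnp.array([[0.0, 0.0], [0.5, 0.0], [3.0, 3.0]])
adj = build_adjacency(positions, radius=1.0)
features = jnp.eye(3)  # one-hot agent features
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
w_self = jax.random.normal(k1, (3, 4))
w_neigh = jax.random.normal(k2, (3, 4))
out = mean_message_passing(features, adj, w_self, w_neigh)
```

Agent 2 is out of range of everyone, so its embedding depends only on its own features, while agents 0 and 1 each receive the other's features.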

Describe the solution you'd like

I am trying to incorporate a GNN layer, such as a graph TransformerConv or a generic message-passing layer, into MAPPO, similar to the approach used in InforMARL. However, since I am new to JAX, I am running into issues.
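A single-head, dense version of a TransformerConv-style attention layer could look like the sketch below. It follows the general form x_i' = W_1 x_i + Σ_j α_ij W_2 x_j, where α_ij is a softmax over the neighbors of i of scaled query-key dot products. It is an illustrative assumption, not InforMARL's or any library's implementation: edge features and multiple heads are omitted, and `init_transformer_conv` and `transformer_conv` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def init_transformer_conv(key, in_dim: int, out_dim: int) -> dict:
    """Hypothetical initializer for a single-head TransformerConv-style layer."""
    keys = jax.random.split(key, 4)
    scale = 1.0 / jnp.sqrt(in_dim)
    names = ("w_skip", "w_value", "w_query", "w_key")
    return {n: jax.random.normal(k, (in_dim, out_dim)) * scale for n, k in zip(names, keys)}


def transformer_conv(params: dict, x: jnp.ndarray, adj: jnp.ndarray) -> jnp.ndarray:
    """Dense, single-head graph attention in the spirit of TransformerConv.

    x: (num_nodes, in_dim) node features.
    adj: (num_nodes, num_nodes) 0/1 mask; adj[i, j] = 1 means node i attends to node j.
    """
    q = x @ params["w_query"]
    k = x @ params["w_key"]
    v = x @ params["w_value"]
    scores = (q @ k.T) / jnp.sqrt(q.shape[-1])
    # Mask out non-neighbors before the softmax.
    scores = jnp.where(adj > 0, scores, -1e9)
    # Multiplying by adj zeroes the attention row of isolated nodes.
    alpha = jax.nn.softmax(scores, axis=-1) * adj
    return x @ params["w_skip"] + alpha @ v


params = init_transformer_conv(jax.random.PRNGKey(0), in_dim=3, out_dim=8)
x = jnp.ones((4, 3))
adj = jnp.array(
    [[0, 1, 1, 0], [1, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]], dtype=jnp.float32
)
h = transformer_conv(params, x, adj)  # shape (4, 8)
```

A dense adjacency mask keeps every array shape static, which suits `jax.jit` and batching with `jax.vmap`; the trade-off is O(N²) memory in the number of nodes, which is usually acceptable for small agent counts.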

Describe alternatives you've considered

I have found standalone JAX implementations of Graph Attention Networks (GAT) and Graph Convolutional Networks (GCN), but I have yet to see any of them incorporated into a MARL actor-critic setup.
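One plausible way to wire a GCN into an actor-critic is sketched below. A shared GCN layer (with the symmetric normalization D^{-1/2}(A + I)D^{-1/2} from Kipf and Welling) produces per-agent embeddings. An actor head maps each embedding to action logits, and a critic head mean-pools the embeddings into a single value, as a centralized critic would. The parameter layout and function names are hypothetical; this is not how Mava's MAPPO is structured.

```python
import jax
import jax.numpy as jnp


def gcn_normalize(adj):
    """Symmetric GCN normalization: D^{-1/2} (A + I) D^{-1/2}."""
    a_hat = adj + jnp.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / jnp.sqrt(a_hat.sum(axis=-1))
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]


def init_params(key, obs_dim, hidden, num_actions):
    """Hypothetical parameter layout: one GCN layer plus actor and critic heads."""
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "gcn": jax.random.normal(k1, (obs_dim, hidden)) * 0.1,
        "actor": jax.random.normal(k2, (hidden, num_actions)) * 0.1,
        "critic": jax.random.normal(k3, (hidden, 1)) * 0.1,
    }


def actor_critic(params, obs, adj):
    """obs: (num_agents, obs_dim); adj: (num_agents, num_agents).

    Returns per-agent action logits and one shared value estimate.
    """
    h = jax.nn.relu(gcn_normalize(adj) @ obs @ params["gcn"])
    logits = h @ params["actor"]  # (num_agents, num_actions)
    # Mean-pool agent embeddings into a single input for a centralized critic.
    value = (h.mean(axis=0) @ params["critic"])[0]
    return logits, value


params = init_params(jax.random.PRNGKey(0), obs_dim=5, hidden=16, num_actions=4)
obs = jnp.ones((3, 5))
adj = jnp.ones((3, 3)) - jnp.eye(3)  # fully connected, no self-loops
logits, value = actor_critic(params, obs, adj)
# Sample one discrete action per agent from the actor's logits.
actions = jax.random.categorical(jax.random.PRNGKey(1), logits)
```

Sharing the GCN between actor and critic is only one design choice; separate encoders for each head are equally valid and may train more stably.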

How do we know when the implementation of this feature is complete?


I believe these additions would be of great interest to the Mava user community and would open up new avenues for research.

Thank you for considering this request.

sash-a commented 1 month ago

Hi @Jaroan thanks for the issue!

This should definitely be possible and is an interesting line of work; however, it's not on our roadmap for the near future. But this is exactly why we created Mava: it should be relatively easy to fork Mava and start hacking on MAPPO to integrate GNNs. If it works, we're more than happy to accept pull requests.

If there's anything you don't understand about how Mava works, we're happy to explain. To add GNNs to MAPPO, I'd start by replacing this block with a GNN of the right size and seeing what happens. (Ignore all the Hydra stuff in that block; it's only there so we can instantiate a network through config. Just write your network directly in the MAPPO file.)
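When swapping in a GNN, "the right size" mostly means matching the input and output shapes of the block being replaced. The sketch below assumes, without checking against Mava's code, that the network receives observations batched as `(batch, num_agents, obs_dim)` and that an adjacency matrix is passed alongside them. It writes the GNN for a single environment, lifts it over the batch with `jax.vmap`, and uses `jax.eval_shape` to check the output shape without running the computation. `gnn_torso` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def gnn_torso(w, obs, adj):
    """Hypothetical per-environment torso: one mean-aggregation GNN layer.

    obs: (num_agents, obs_dim); adj: (num_agents, num_agents).
    Returns (num_agents, hidden), the role an MLP torso's output would play.
    """
    degree = jnp.maximum(adj.sum(axis=-1, keepdims=True), 1.0)
    # Combine each agent's own observation with the mean of its neighbors'.
    return jax.nn.relu(((adj @ obs) / degree + obs) @ w)


# Map over the leading batch axis (e.g. parallel environments); weights are shared.
batched_torso = jax.vmap(gnn_torso, in_axes=(None, 0, 0))

batch, num_agents, obs_dim, hidden = 8, 3, 5, 64
w = jnp.zeros((obs_dim, hidden))
obs_spec = jax.ShapeDtypeStruct((batch, num_agents, obs_dim), jnp.float32)
adj_spec = jax.ShapeDtypeStruct((batch, num_agents, num_agents), jnp.float32)
# Shape check only: no FLOPs are executed here.
out_spec = jax.eval_shape(batched_torso, w, obs_spec, adj_spec)
print(out_spec.shape)
```

If this shape matches what the original torso produced, the downstream actor and critic heads can likely stay as they are. In practice, the adjacency would have to be built from the environment state or observations.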