facebookresearch / BenchMARL

A collection of MARL benchmarks based on TorchRL
https://benchmarl.readthedocs.io/
MIT License

How to integrate the `Multi-Agent Transformer` algorithm in BenchMARL #117

Closed: fmxFranky closed this issue 2 months ago

fmxFranky commented 3 months ago

Hi there,

First off, I want to say that BenchMARL looks like a fantastic resource for the MARL community! I've been following the project and I'm excited about the potential for benchmarking various algorithms.

I recently came across the Multi-Agent Transformer (MAT) by PKU-MARL, which is showcased in the PKU-MARL GitHub repository and the accompanying paper on arXiv. MAT introduces an innovative approach to MARL with its centralized joint-observation-action Transformer structure, and I believe it could be a valuable addition to BenchMARL.

[figure: MAT architecture diagram]

I'm curious about how I might go about integrating MAT into BenchMARL. Could you provide some guidance or suggestions on this process? I'm eager to contribute to the project and explore the synergy between MAT and BenchMARL.

Thanks for your great work on this, and I'm looking forward to any advice you can offer!

matteobettini commented 2 months ago

Hey!

Thanks for the comment!

So, to understand better how we can frame this, I have a few points:

fmxFranky commented 2 months ago

Dear Developer,

The algorithm I cited is actually a centralized multi-agent reinforcement learning algorithm that is independent of the time dimension. At any given timestep, it takes in the joint observations of all agents and outputs their corresponding actions via a Transformer-based model. During the training phase, its decoder predicts all agents' actions in parallel; during the execution phase, however, the decoder outputs the agents' actions autoregressively. My current question is whether the framework supports the implementation of such a cross-agent, centralized, Transformer-based network structure.
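
For readers less familiar with MAT, here is a toy, self-contained sketch of the two phases described above. The linear encoder/decoder are only stand-ins for MAT's actual Transformer blocks (which attend over all earlier agents rather than just the previous one), and every name here is illustrative; the reference code from MAT's repo is posted further down in this thread.

import torch
import torch.nn.functional as F
from torch.distributions import Categorical

n_agents, obs_dim, action_dim, batch = 3, 8, 5, 4
obs = torch.randn(batch, n_agents, obs_dim)
encoder = torch.nn.Linear(obs_dim, 16)                   # stand-in for MAT's Transformer encoder
decoder = torch.nn.Linear(16 + action_dim, action_dim)   # stand-in for MAT's Transformer decoder

obs_rep = encoder(obs)  # centralized embedding of the joint observation, (batch, n_agents, 16)

# Execution phase: actions come out agent by agent (autoregressive over agents, not over time)
prev = torch.zeros(batch, action_dim)                    # "start token" seen by agent 0
sampled = []
for i in range(n_agents):
    logits_i = decoder(torch.cat([obs_rep[:, i], prev], dim=-1))
    a_i = Categorical(logits=logits_i).sample()
    sampled.append(a_i)
    prev = F.one_hot(a_i, action_dim).float()            # the next agent conditions on this action
actions = torch.stack(sampled, dim=1)                    # (batch, n_agents)

# Training phase: the sampled joint action is known, so all agents are decoded in parallel
shifted = torch.zeros(batch, n_agents, action_dim)
shifted[:, 1:] = F.one_hot(actions[:, :-1], action_dim).float()
logits = decoder(torch.cat([obs_rep, shifted], dim=-1))  # one pass for every agent
log_probs = Categorical(logits=logits).log_prob(actions) # goes into the PPO-style loss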

Thank you!

matteobettini commented 2 months ago

Ah I see! Even better then. It is possible to implement it as a model (to be used as a policy) and train it with any algorithm in the library. The only thing that is strange is that it would be a centralized policy (as it needs access to all other agents' info).

You can also create a new algorithm representing the one described in the paper and train the model with that.

When you say that the actions are computed in parallel during training, what do you mean? You still need access to the actions of the prior agents to compute the next agent's action, right? Is it just that they are output at the same time? Because this would also need to be the case for execution (i.e., we run the model autoregressively for all agents within one forward pass per step).

If you want to look at a model with similar behavior to understand how this would work in BenchMARL, you can look at the GNN: https://github.com/facebookresearch/BenchMARL/blob/main/benchmarl/models/gnn.py. In particular, if you use a GNN with GATv2 and full topology in BenchMARL, you obtain a fully connected attention layer over the agents (similar to a transformer); see the example below. The difference between that and MAT seems to be the autoregressive nature of action prediction, where agents are given an order and later agents have access to earlier agents' info.
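
For reference, instantiating that fully connected GATv2 attention model should look roughly like this (a sketch based on the GnnConfig fields in the gnn.py linked above and the README experiment API; double-check the exact names against the BenchMARL version you are using):

import torch_geometric
from benchmarl.algorithms import MappoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import GnnConfig

# Full topology + GATv2 convolutions = a fully connected attention layer over the agents
model_config = GnnConfig(
    topology="full",
    self_loops=False,
    gnn_class=torch_geometric.nn.conv.GATv2Conv,
    gnn_kwargs={},
)

experiment = Experiment(
    task=VmasTask.BALANCE.get_from_yaml(),
    algorithm_config=MappoConfig.get_from_yaml(),
    model_config=model_config,
    critic_model_config=model_config,
    seed=0,
    config=ExperimentConfig.get_from_yaml(),
)
experiment.run()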

Regarding this ordering, I guess my main question is how we choose it: what makes us decide that an agent comes first and has access to less info than the others?

Also, one other concern I have regards this centralization. If the transformer accesses all agents' info, what makes this a multi-agent problem and not simply a single-agent problem with a factored action and observation space? When I designed BenchMARL, I did not envision the possibility of centralized policies, since in my view they make the problem not MARL. In the arXiv paper the authors propose a decentralized version of the model (which still relies on full observability but removes the autoregressive actions). This would be equivalent to the fully connected GNN with GATv2 I mentioned earlier (which we already have in the library).

fmxFranky commented 2 months ago

Hi! Firstly, thanks for your reply!

Regarding the action prediction of MAT: in the training stage, since we can directly sample from the off-policy buffer or use on-policy samples, the joint action of the agents at a timestep is already determined. That is to say, we can directly use the sampled actions as the labels for the Transformer and then compute the RL loss (such as the PPO loss). In this case, the agent order can be arbitrary.

However, in the execution stage, as you noted, we first need to determine an agent order and then output actions agent by agent autoregressively. There are some studies on how to choose this order, but for MAT the impact is not significant, so we can ignore the effect of the output order.

Finally, regarding your concern about centralized policies: indeed, for real-world applications this CTCE mode arguably cannot be called a multi-agent RL algorithm. However, some recent studies have shown that a decentralized policy obtained by distilling such a centralized policy can achieve stronger performance than training directly in the CTDE paradigm. On the other hand, this structure can also be regarded as a MARL paradigm that allows implicit communication between agents, and there is a lot of research on MARL with communication. So I personally think MAT is still worth implementing.

I have also looked at the GNN implementation in the repo. By analogy with MAT, it can indeed be regarded as the encoder part. What I currently want to figure out is how to implement the neural network for the decoder part. From another point of view, even if we regard MAT as a single-agent centralized policy and go back to the underlying TorchRL, MAT is completely different from the offline Decision Transformer implemented in TorchRL. If you can give me some advice on how to implement MAT in TorchRL or BenchMARL, I would be very grateful!

matteobettini commented 2 months ago

Ok got your points, thanks!

So it seems that the core is implementing the decoder, since we can already do the encoder.

I think the main problem is that BenchMARL policies output the parameters of distributions, and the sampling of the actions is done later. Thus, it seems that running the proposed implementation would need n_agents forward passes for each environment step, whereas the way BenchMARL works is that we take one forward pass for all agents per step.

Therefore, in this forward pass we cannot compute actions autoregressively, as actions are sampled outside the model. What we could do is output the agents' logits autoregressively within one pass (one agent at a time).

This would change the paradigm to be autoregressive on logits rather than on actions; a rough sketch of what such a model could look like is below.
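
Just to sketch the idea (this is a hypothetical toy module, not the BenchMARL Model interface and not MAT's code): the whole autoregression over agents happens inside a single forward call, and what gets passed from one agent to the next are the logits themselves rather than sampled actions.

import torch
import torch.nn as nn

class AutoregressiveLogitsToy(nn.Module):
    # Toy sketch: one forward pass returns per-agent logits, with later agents
    # conditioned on the logits (not the sampled actions) of earlier agents.
    def __init__(self, n_agents: int, obs_dim: int, action_dim: int, hidden: int = 64):
        super().__init__()
        self.n_agents = n_agents
        self.action_dim = action_dim
        self.encoder = nn.Linear(obs_dim, hidden)                  # stand-in encoder
        self.decoder = nn.Linear(hidden + action_dim, action_dim)  # stand-in decoder

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # obs: (batch, n_agents, obs_dim) -> logits: (batch, n_agents, action_dim)
        obs_rep = torch.relu(self.encoder(obs))
        prev_logits = torch.zeros(obs.shape[0], self.action_dim, device=obs.device)
        out = []
        for i in range(self.n_agents):
            logits_i = self.decoder(torch.cat([obs_rep[:, i], prev_logits], dim=-1))
            out.append(logits_i)
            prev_logits = logits_i  # autoregressive on logits; sampling happens outside
        return torch.stack(out, dim=1)

The distribution configured in the algorithm would then consume these logits and sample the actions outside the model, as usual.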

Now maybe one question we need to ask is how actions/logits are tokenized in continuous and discrete domains.

fmxFranky commented 2 months ago

Your understanding is correct, and the way of tokenizing is different in the training and inference phases. You can read the following code (which takes the encoder's embeddings as input and is copied from MAT's repo) for more details:

import torch
from torch.distributions import Categorical, Normal
from torch.nn import functional as F

def discrete_autoregreesive_act(decoder, obs_rep, obs, batch_size, n_agent, action_dim, tpdv,
                                available_actions=None, deterministic=False):
    # Execution-time decoding: actions are produced one agent at a time.
    # shifted_action carries a start-of-sequence flag (index 0) for agent 0 and the
    # one-hot action of the previous agent for the agents that follow.
    shifted_action = torch.zeros((batch_size, n_agent, action_dim + 1)).to(**tpdv)
    shifted_action[:, 0, 0] = 1
    output_action = torch.zeros((batch_size, n_agent, 1), dtype=torch.long)
    output_action_log = torch.zeros_like(output_action, dtype=torch.float32)

    for i in range(n_agent):
        # Decode agent i conditioned on the encoder output and the one-hot actions
        # of the agents that come before it in the ordering
        logit = decoder(shifted_action, obs_rep, obs)[:, i, :]
        if available_actions is not None:
            logit[available_actions[:, i, :] == 0] = -1e10  # mask out unavailable actions

        distri = Categorical(logits=logit)
        action = distri.probs.argmax(dim=-1) if deterministic else distri.sample()
        action_log = distri.log_prob(action)

        output_action[:, i, :] = action.unsqueeze(-1)
        output_action_log[:, i, :] = action_log.unsqueeze(-1)
        if i + 1 < n_agent:
            shifted_action[:, i + 1, 1:] = F.one_hot(action, num_classes=action_dim)
    return output_action, output_action_log

def discrete_parallel_act(decoder, obs_rep, obs, action, batch_size, n_agent, action_dim, tpdv,
                          available_actions=None):
    # Training-time decoding: the sampled joint action is already known, so all agents'
    # logits are produced in a single decoder pass, with the previous agent's action
    # "teacher-forced" through shifted_action.
    one_hot_action = F.one_hot(action.squeeze(-1), num_classes=action_dim)  # (batch, n_agent, action_dim)
    shifted_action = torch.zeros((batch_size, n_agent, action_dim + 1)).to(**tpdv)
    shifted_action[:, 0, 0] = 1
    shifted_action[:, 1:, 1:] = one_hot_action[:, :-1, :]
    logit = decoder(shifted_action, obs_rep, obs)
    if available_actions is not None:
        logit[available_actions == 0] = -1e10

    distri = Categorical(logits=logit)
    action_log = distri.log_prob(action.squeeze(-1)).unsqueeze(-1)
    entropy = distri.entropy().unsqueeze(-1)
    return action_log, entropy

def continuous_autoregreesive_act(decoder, obs_rep, obs, batch_size, n_agent, action_dim, tpdv,
                                  deterministic=False):
    # Continuous case: the previous agent's raw action vector is fed directly
    # (agent 0 conditions on a zero vector as the "start token")
    shifted_action = torch.zeros((batch_size, n_agent, action_dim)).to(**tpdv)
    output_action = torch.zeros((batch_size, n_agent, action_dim), dtype=torch.float32)
    output_action_log = torch.zeros_like(output_action, dtype=torch.float32)

    for i in range(n_agent):
        act_mean = decoder(shifted_action, obs_rep, obs)[:, i, :]
        action_std = torch.sigmoid(decoder.log_std) * 0.5

        # log_std = torch.zeros_like(act_mean).to(**tpdv) + decoder.log_std
        # distri = Normal(act_mean, log_std.exp())
        distri = Normal(act_mean, action_std)
        action = act_mean if deterministic else distri.sample()
        action_log = distri.log_prob(action)

        output_action[:, i, :] = action
        output_action_log[:, i, :] = action_log
        if i + 1 < n_agent:
            shifted_action[:, i + 1, :] = action

        # print("act_mean: ", act_mean)
        # print("action: ", action)

    return output_action, output_action_log

def continuous_parallel_act(decoder, obs_rep, obs, action, batch_size, n_agent, action_dim, tpdv):
    shifted_action = torch.zeros((batch_size, n_agent, action_dim)).to(**tpdv)
    shifted_action[:, 1:, :] = action[:, :-1, :]

    act_mean = decoder(shifted_action, obs_rep, obs)
    action_std = torch.sigmoid(decoder.log_std) * 0.5
    distri = Normal(act_mean, action_std)

    # log_std = torch.zeros_like(act_mean).to(**tpdv) + decoder.log_std
    # distri = Normal(act_mean, log_std.exp())

    action_log = distri.log_prob(action)
    entropy = distri.entropy()
    return action_log, entropy

Is there a way to implement this in BenchMARL and make it work?

matteobettini commented 2 months ago

As I said above, BenchMARL models do not deal with distributions and just output the logits (agnostic of what distribution these will go to).

The way you could implement this in BenchMARL is by feeding shifted_logits instead of shifted_action to the decoder, so that you remove any distribution dependence from the model.

At training time, you could either save the logits from collection in the buffer and do the computation in parallel, or recompute them autoregressively as during collection; a sketch of the first option is below.
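
A minimal sketch of the first option (hypothetical helper, assuming the per-agent logits produced during collection are stored in the buffer alongside the actions):

import torch

def shift_logits(collection_logits: torch.Tensor) -> torch.Tensor:
    # collection_logits: (batch, n_agents, action_dim), saved at collection time.
    # Agent 0 gets a zero "start token"; agent i is conditioned on agent i-1's logits,
    # mirroring shifted_action in the MAT code above but without any sampling.
    shifted = torch.zeros_like(collection_logits)
    shifted[:, 1:, :] = collection_logits[:, :-1, :]
    return shifted

The decoder can then be run once over all agents with these shifted logits, in the same spirit as the parallel training functions in the MAT code above.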

fmxFranky commented 2 months ago

Thank you for your reply, I will consider how to implement it ^_^