Theohhhu / UPDeT

Official Implementation of 'UPDeT: Universal Multi-agent Reinforcement Learning via Policy Decoupling with Transformers', ICLR 2021 (Spotlight)
MIT License

In which part do you implement policy decoupling? #12

Closed: donutQQ closed this issue 2 years ago

donutQQ commented 2 years ago

Hello, I am very interested in your work! I have studied the code, especially the class "TransformerAggregationAgent", but I have not found where you implement the policy decoupling. The only thing I can find is:

    q_agg = torch.mean(outputs, 1)
    q = self.q_linear(q_agg)

I am confused why you average the outputs over their second dimension and then map the result back to the actions. Could you please explain the motivation for this part? I really look forward to your reply.

Thanks!
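For reference, the two quoted lines amount to the following aggregation step. This is a minimal, self-contained sketch; the shapes, sizes, and the standalone "q_linear" layer below are illustrative assumptions, not values taken from the repo:

    import torch
    import torch.nn as nn

    # Illustrative sizes only; the real entity count and embedding width come from the repo's config.
    batch, n_entities, emb_dim, n_actions = 4, 8, 32, 12

    q_linear = nn.Linear(emb_dim, n_actions)            # shared output head, as in the quoted snippet
    outputs = torch.randn(batch, n_entities, emb_dim)   # stand-in for the per-entity transformer embeddings

    q_agg = torch.mean(outputs, 1)   # average over the entity/token dimension -> [batch, emb_dim]
    q = q_linear(q_agg)              # map the pooled embedding to action values -> [batch, n_actions]

In other words, all entity embeddings are pooled into a single vector before any Q-value is produced, which is what the policy-decoupled forward pass quoted later in this thread avoids.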

hhhusiyi-monash commented 2 years ago

Hi there,

Thanks for your interest. TransformerAggregationAgent is a transformer-based agent without the policy decoupling strategy. In Figure 4(a) of our paper, you can see that without policy decoupling, the transformer-based agent performs even worse than a classical GRU/LSTM, which demonstrates the effectiveness of this strategy.

Any further concerns are welcome.

ouyangshixiong commented 2 years ago

I am not the author of this algorithm, but I have carefully read through all of the code, and I think the code for the "policy decoupling" is the following:

    def forward(self, inputs, hidden_state, task_enemy_num, task_ally_num):
        outputs, _ = self.transformer.forward(inputs, hidden_state, None)

        # the first token's output gives the Q-values for the 6 basic actions
        # (no_op, stop, up, down, left, right)
        q_basic_actions = self.q_basic(outputs[:, 0, :])

        # the last token's output is carried over as the recurrent hidden state
        h = outputs[:, -1:, :]

        q_enemies_list = []

        # each enemy token gets its own attack Q-value from the same shared head
        for i in range(task_enemy_num):
            q_enemy = self.q_basic(outputs[:, 1 + i, :])
            q_enemy_mean = torch.mean(q_enemy, 1, True)
            q_enemies_list.append(q_enemy_mean)

        # stack the per-enemy attack Q-values and squeeze out the size-1 dimensions
        q_enemies = torch.stack(q_enemies_list, dim=1).squeeze()

        # concatenate the basic-action Q-values with the per-enemy attack Q-values
        q = torch.cat((q_basic_actions, q_enemies), 1)

        return q, h

As the paper says, it uses a Transformer to process the input (the observation), so "inputs" should be the observation, and "outputs" should be all agents' raw values (in short, "R"), including the enemies' (Figure 7 in the paper).

The author uses a heatmap to explain the relationship between the self-attention matrix and the final strategy (Figure 6 in the paper).
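As a rough illustration of the shapes involved, below is a minimal, self-contained sketch of a forward pass through such a policy-decoupling head. The class name "DecoupledHeadSketch", the stand-in "body" layer, and all sizes are illustrative assumptions, not the repo's actual modules or configuration:

    import torch
    import torch.nn as nn

    # Illustrative sizes; the real ones depend on the SMAC map and the repo's config.
    emb_dim, n_basic_actions, n_enemies, n_allies, batch = 32, 6, 5, 4, 2

    class DecoupledHeadSketch(nn.Module):
        """Minimal stand-in for the policy-decoupling head quoted above."""
        def __init__(self):
            super().__init__()
            # stand-in for the transformer body: keeps one embedding per input token
            self.body = nn.Linear(emb_dim, emb_dim)
            # shared output head, reused for the basic actions and for every enemy token
            self.q_basic = nn.Linear(emb_dim, n_basic_actions)

        def forward(self, token_embs):
            outputs = self.body(token_embs)                    # [batch, n_tokens, emb_dim]
            q_basic_actions = self.q_basic(outputs[:, 0, :])   # [batch, n_basic_actions]
            h = outputs[:, -1:, :]                             # last token carried as hidden state
            q_enemies = []
            for i in range(n_enemies):
                q_enemy = self.q_basic(outputs[:, 1 + i, :])         # [batch, n_basic_actions]
                q_enemies.append(q_enemy.mean(dim=1, keepdim=True))  # one attack Q per enemy
            q_enemies = torch.cat(q_enemies, dim=1)            # [batch, n_enemies]
            q = torch.cat((q_basic_actions, q_enemies), dim=1)  # [batch, 6 + n_enemies]
            return q, h

    # Assumed token layout for this sketch: own token, enemy tokens, ally tokens, hidden-state token.
    n_tokens = 1 + n_enemies + n_allies + 1
    agent = DecoupledHeadSketch()
    q, h = agent(torch.randn(batch, n_tokens, emb_dim))
    print(q.shape)  # torch.Size([2, 11]): 6 basic-action Qs followed by one attack Q per enemy

The key point is that the output dimension grows with the number of enemy tokens, so the same shared head can be reused when the number of enemies changes, instead of mapping a single pooled vector to a fixed-size action space.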

hhhusiyi-monash commented 2 years ago

Thanks for your detailed explanation. I have pinned this issue for people who have the same confusion.