facebookresearch / hanabi_SAD

Simplified Action Decoder for Deep Multi-Agent Reinforcement Learning

Details of Implementation #24

Closed 0xJchen closed 3 years ago

0xJchen commented 3 years ago

I found it hard to understand the purpose of the following two lines in r2d2.py: legal_q = (1 + q - q.min()) * legal_move in R2D2Net.forward(), and legal_adv = (1 + adv - adv.min()) * legal_move in R2D2Agent.greedy_act(). Take the first snippet as an example, and assume the parameters follow tool/dev.sh.

# in R2D2Net.forward():
...
q = self._duel(v, a, legal_move)

# q: [seq_len, batch, num_action]
# action: [seq_len, batch]
qa = q.gather(2, action.unsqueeze(2)).squeeze(2)  # for target_net, it's the greedy action

assert q.size() == legal_move.size()
legal_q = (1 + q - q.min()) * legal_move
print("q, legal_q:", q.shape, legal_q.shape)
# shapes: q [1, 80, 21], legal_q [1, 80, 21], legal_move [1, 80, 21]
# greedy_action: [seq_len, batch]
greedy_action = legal_q.argmax(2).detach()

I agree the q values should be filtered by legal_move, but what's the purpose of 1 + q - q.min()?

hengyuan-hu commented 3 years ago

It is to make q > 0 so that the ordering is preserved after applying the legal_move mask. Imagine if not all q values are > 0: we may have q = [0.2, -0.5, -0.1] and legal_move = [0.0, 1.0, 1.0]. Then (q * legal_move).argmax() will return an illegal move, because the masked-out entry becomes 0.0, which is larger than any negative legal q value.
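The pitfall and the fix can be reproduced with a few lines of plain Python (an illustrative sketch using the example values above, not code from the repo):

```python
q = [0.2, -0.5, -0.1]         # Q-values; index 0 is the illegal move here
legal_move = [0.0, 1.0, 1.0]  # mask: 1.0 = legal, 0.0 = illegal

# Naive masking: zeroing out illegal moves still lets a 0.0 entry beat
# negative legal Q-values, so argmax picks the illegal index 0.
naive = [qi * mi for qi, mi in zip(q, legal_move)]
naive_argmax = max(range(len(naive)), key=naive.__getitem__)

# Shifting by (1 - min(q)) makes every Q-value strictly positive without
# changing their relative order, so after masking every legal entry
# is larger than the zeroed-out illegal ones.
shifted = [(1 + qi - min(q)) * mi for qi, mi in zip(q, legal_move)]
shifted_argmax = max(range(len(shifted)), key=shifted.__getitem__)

print(naive_argmax)    # 0 -- an illegal move
print(shifted_argmax)  # 2 -- the best legal move (q = -0.1 > -0.5)
```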

0xJchen commented 3 years ago

Thanks for the clarification!