Closed Icbyone closed 1 year ago
作者你好,在g2anet.py中的forward里有这样一行代码(大约在73行): hard_weights = hard_weights[:, 1].view(-1, self.args.n_agents, 1, self.args.n_agents - 1) 对于前半部分hard_weights = hard_weights[:, 1],我理解的是从每个batch_size中取出了第二列,请问为什么要这样取呢?
hard_weights = hard_weights[:, 1].view(-1, self.args.n_agents, 1, self.args.n_agents - 1)
hard_weights = hard_weights[:, 1]
研究了一下发现是我改的时候弄错维度了,现在明白了
作者你好,在g2anet.py中的forward里有这样一行代码(大约在73行):
hard_weights = hard_weights[:, 1].view(-1, self.args.n_agents, 1, self.args.n_agents - 1)
对于前半部分hard_weights = hard_weights[:, 1]
,我理解的是从每个batch_size中取出了第二列,请问为什么要这样取呢?