leaderj1001 / Stand-Alone-Self-Attention

Implementing Stand-Alone Self-Attention in Vision Models using PyTorch

v_out = torch.cat((v_out_h + self.rel_h, v_out_w + self.rel_w), dim=1) #7

Open scutzck033 opened 4 years ago

scutzck033 commented 4 years ago

Hi, are the following lines wrong?

v_out_h, v_out_w = v_out.split(self.out_channels // 2, dim=1)
v_out = torch.cat((v_out_h + self.rel_h, v_out_w + self.rel_w), dim=1)

Shouldn't it be something like the following instead?

k_out_h, k_out_w = k_out.split(self.out_channels // 2, dim=1)
k_out = torch.cat((k_out_h + self.rel_h, k_out_w + self.rel_w), dim=1)

Because the relative distance embedding is supposed to be added to the "key" rather than the "value", right?
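(For reference, the relative-attention output in the paper, eq. 3 up to notation, is roughly

y_{ij} = \sum_{a,b \in \mathcal{N}_k(i,j)} \mathrm{softmax}_{ab}\left(q_{ij}^\top k_{ab} + q_{ij}^\top r_{a-i,\,b-j}\right) v_{ab}

so the relative embedding r enters the attention logits alongside the key, while the value v_{ab} is left unchanged.)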

leaderj1001 commented 4 years ago

Thanks for your comment. I'll check it and train the model again. Thank you :)

yassouali commented 4 years ago

I think you're right: the relative embeddings are added to the keys and not the values, as detailed in eq. 3 of the paper. But given that the keys are unfolded, the shifted keys need to be computed before that, so we first shift the keys based on the relative positions and then unfold them:

# add the relative position embeddings to the keys, then unfold into k x k patches
k_out_h, k_out_w = k_out.split(self.out_channels // 2, dim=1)
k_out = torch.cat((k_out_h + self.rel_h, k_out_w + self.rel_w), dim=1)
k_out = F.unfold(k_out, kernel_size=(self.kernel_size, self.kernel_size), stride=self.stride)
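For anyone landing here, below is a minimal, self-contained sketch of a layer where the relative position embeddings are added to the keys after unfolding the k x k neighborhoods. It is not the repo's exact code: the class name SASAConv, the fixed stride of 1, and the single-head dot-product logits (the repo uses grouped, multi-head attention) are assumptions for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SASAConv(nn.Module):
    # Single-head stand-alone self-attention over k x k neighborhoods (stride 1, odd kernel_size).
    def __init__(self, in_channels, out_channels, kernel_size=7):
        super().__init__()
        assert out_channels % 2 == 0, "out_channels must split evenly into rel_h / rel_w halves"
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2

        self.query_conv = nn.Conv2d(in_channels, out_channels, 1)
        self.key_conv = nn.Conv2d(in_channels, out_channels, 1)
        self.value_conv = nn.Conv2d(in_channels, out_channels, 1)

        # One learned embedding per row offset (rel_h) and per column offset (rel_w) in the window.
        self.rel_h = nn.Parameter(torch.randn(out_channels // 2, 1, 1, kernel_size, 1))
        self.rel_w = nn.Parameter(torch.randn(out_channels // 2, 1, 1, 1, kernel_size))

    def forward(self, x):
        batch, _, height, width = x.size()
        padded = F.pad(x, [self.padding] * 4)

        q_out = self.query_conv(x)      # (B, C, H, W)
        k_out = self.key_conv(padded)
        v_out = self.value_conv(padded)

        # Gather the k x k neighborhood of keys and values around every query position.
        k_out = k_out.unfold(2, self.kernel_size, 1).unfold(3, self.kernel_size, 1)  # (B, C, H, W, k, k)
        v_out = v_out.unfold(2, self.kernel_size, 1).unfold(3, self.kernel_size, 1)

        # Add the relative position embeddings to the KEYS (not the values):
        # the first half of the channels carries the row offset, the second half the column offset.
        k_out_h, k_out_w = k_out.split(self.out_channels // 2, dim=1)
        k_out = torch.cat((k_out_h + self.rel_h, k_out_w + self.rel_w), dim=1)

        k_out = k_out.contiguous().view(batch, self.out_channels, height, width, -1)  # (B, C, H, W, k*k)
        v_out = v_out.contiguous().view(batch, self.out_channels, height, width, -1)
        q_out = q_out.view(batch, self.out_channels, height, width, 1)

        # q^T k + q^T r in the logits; softmax over the neighborhood; weighted sum of the values.
        logits = (q_out * k_out).sum(dim=1, keepdim=True)   # (B, 1, H, W, k*k)
        attn = F.softmax(logits, dim=-1)
        return (attn * v_out).sum(dim=-1)                    # (B, C, H, W)

# e.g. SASAConv(64, 64, kernel_size=7)(torch.randn(2, 64, 32, 32)) -> (2, 64, 32, 32)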