z1069614715 / objectdetection_script

一些关于目标检测的脚本的改进思路代码,详细请看readme.md
5.3k stars 476 forks source link

S2Attention bug? #13

Closed vehiclexzh closed 1 year ago

vehiclexzh commented 1 year ago

class S2Attention(nn.Module): def init(self, channels=512): super().init() self.mlp1 = nn.Linear(channels, channels * 3) self.mlp2 = nn.Linear(channels, channels) self.split_attention = SplitAttention()

这一行: self.split_attention = SplitAttention() 好像要改为: self.split_attention = SplitAttention(channels) 否则会报“矩阵乘法”相关的错误

(即,class SplitAttention(nn.Module): 里的 hat_a = self.mlp2(self.gelu(self.mlp1(a))) # bs,kc 不运行?)