Closed vehiclexzh closed 1 year ago
class S2Attention(nn.Module): def __init__(self, channels=512): super().__init__() self.mlp1 = nn.Linear(channels, channels * 3) self.mlp2 = nn.Linear(channels, channels) self.split_attention = SplitAttention()
这一行: self.split_attention = SplitAttention() 好像要改为: self.split_attention = SplitAttention(channels) 否则会报“矩阵乘法”相关的错误
(即,class SplitAttention(nn.Module): 里的 hat_a = self.mlp2(self.gelu(self.mlp1(a))) # bs,kc 不运行?)
class S2Attention(nn.Module): def __init__(self, channels=512): super().__init__() self.mlp1 = nn.Linear(channels, channels * 3) self.mlp2 = nn.Linear(channels, channels) self.split_attention = SplitAttention()
这一行: self.split_attention = SplitAttention() 好像要改为: self.split_attention = SplitAttention(channels) 否则会报“矩阵乘法”相关的错误
(即,class SplitAttention(nn.Module): 里的 hat_a = self.mlp2(self.gelu(self.mlp1(a))) # bs,kc 不运行?)