luojq-sysysdcs opened this issue 4 years ago
```python
import torch.nn as nn


class Multi_Head_Attention(nn.Module):
    def __init__(self, dim_model, num_head, dropout=0.0):
        super(Multi_Head_Attention, self).__init__()
        self.num_head = num_head
        assert dim_model % num_head == 0
        self.dim_head = dim_model // self.num_head
        self.fc_Q = nn.Linear(dim_model, num_head * self.dim_head)
        self.fc_K = nn.Linear(dim_model, num_head * self.dim_head)
        self.fc_V = nn.Linear(dim_model, num_head * self.dim_head)
        self.attention = Scaled_Dot_Product_Attention()
        self.fc = nn.Linear(num_head * self.dim_head, dim_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(dim_model)

    def forward(self, x):
        batch_size = x.size(0)
        Q = self.fc_Q(x)
        K = self.fc_K(x)
        V = self.fc_V(x)
        Q = Q.view(batch_size * self.num_head, -1, self.dim_head)
        K = K.view(batch_size * self.num_head, -1, self.dim_head)
        V = V.view(batch_size * self.num_head, -1, self.dim_head)
        # if mask:  # TODO
        #     mask = mask.repeat(self.num_head, 1, 1)  # TODO change this
        scale = K.size(-1) ** -0.5  # scaling factor
        context = self.attention(Q, K, V, scale)
        context = context.view(batch_size, -1, self.dim_head * self.num_head)
        out = self.fc(context)
        out = self.dropout(out)
        out = out + x  # residual connection
        out = self.layer_norm(out)
        return out
```
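For context, `Scaled_Dot_Product_Attention` is referenced but not shown in this issue. A minimal sketch of the standard scaled dot-product attention it presumably implements (an assumption, not necessarily the repository's exact code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Scaled_Dot_Product_Attention(nn.Module):
    """Sketch of standard scaled dot-product attention: softmax(QK^T * scale) V."""

    def __init__(self):
        super(Scaled_Dot_Product_Attention, self).__init__()

    def forward(self, Q, K, V, scale=None):
        # Q, K, V: [batch * num_head, seq_len, dim_head]
        attention = torch.matmul(Q, K.permute(0, 2, 1))  # [batch * num_head, seq_len, seq_len]
        if scale:
            attention = attention * scale
        attention = F.softmax(attention, dim=-1)
        context = torch.matmul(attention, V)  # [batch * num_head, seq_len, dim_head]
        return context
```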
Here you should transpose first and then call `view`: the output of `fc_Q`/`fc_K`/`fc_V` has shape `[batch, seq_len, num_head * dim_head]`, so calling `view(batch_size * self.num_head, -1, self.dim_head)` directly mixes the sequence and head axes instead of splitting the heads out cleanly.
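A minimal sketch of what "transpose first, then view" would look like in `forward`, assuming the rest of the class above is unchanged (this is an illustration of the suggested fix, not the repository's own patch):

```python
    def forward(self, x):
        batch_size = x.size(0)
        Q = self.fc_Q(x)  # [batch, seq_len, num_head * dim_head]
        K = self.fc_K(x)
        V = self.fc_V(x)
        # Split heads: view to [batch, seq_len, num_head, dim_head], move the head
        # axis next to the batch axis, then flatten the two together.
        Q = Q.view(batch_size, -1, self.num_head, self.dim_head)
        Q = Q.transpose(1, 2).reshape(batch_size * self.num_head, -1, self.dim_head)
        K = K.view(batch_size, -1, self.num_head, self.dim_head)
        K = K.transpose(1, 2).reshape(batch_size * self.num_head, -1, self.dim_head)
        V = V.view(batch_size, -1, self.num_head, self.dim_head)
        V = V.transpose(1, 2).reshape(batch_size * self.num_head, -1, self.dim_head)
        scale = K.size(-1) ** -0.5  # scaling factor 1 / sqrt(dim_head)
        context = self.attention(Q, K, V, scale)  # [batch * num_head, seq_len, dim_head]
        # Merge heads: undo the split in the same order (transpose back, then merge).
        context = context.view(batch_size, self.num_head, -1, self.dim_head)
        context = context.transpose(1, 2).reshape(batch_size, -1, self.num_head * self.dim_head)
        out = self.fc(context)
        out = self.dropout(out)
        out = out + x  # residual connection
        out = self.layer_norm(out)
        return out
```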