649453932 / Chinese-Text-Classification-Pytorch

Chinese text classification with TextCNN, TextRNN, FastText, TextRCNN, BiLSTM_Attention, DPCNN, and Transformer, implemented in PyTorch, ready to use out of the box.
MIT License
5.25k stars 1.22k forks

transformer #53

Open luojq-sysysdcs opened 4 years ago

luojq-sysysdcs commented 4 years ago

class Multi_Head_Attention(nn.Module):
    def __init__(self, dim_model, num_head, dropout=0.0):
        super(Multi_Head_Attention, self).__init__()
        self.num_head = num_head
        assert dim_model % num_head == 0
        self.dim_head = dim_model // self.num_head
        self.fc_Q = nn.Linear(dim_model, num_head * self.dim_head)
        self.fc_K = nn.Linear(dim_model, num_head * self.dim_head)
        self.fc_V = nn.Linear(dim_model, num_head * self.dim_head)
        self.attention = Scaled_Dot_Product_Attention()
        self.fc = nn.Linear(num_head * self.dim_head, dim_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(dim_model)

    def forward(self, x):
        batch_size = x.size(0)
        Q = self.fc_Q(x)
        K = self.fc_K(x)
        V = self.fc_V(x)
        Q = Q.view(batch_size * self.num_head, -1, self.dim_head)
        K = K.view(batch_size * self.num_head, -1, self.dim_head)
        V = V.view(batch_size * self.num_head, -1, self.dim_head)
        # if mask:  # TODO
        #     mask = mask.repeat(self.num_head, 1, 1)  # TODO change this
        scale = K.size(-1) ** -0.5  # scaling factor
        context = self.attention(Q, K, V, scale)

        context = context.view(batch_size, -1, self.dim_head * self.num_head)
        out = self.fc(context)
        out = self.dropout(out)
        out = out + x  # residual connection
        out = self.layer_norm(out)
        return out

The view here should be preceded by a transpose: first view to (batch_size, seq_len, num_head, dim_head), transpose the sequence and head axes, and only then merge the heads into the batch dimension. Viewing the (batch_size, seq_len, dim_model) tensor directly as (batch_size * num_head, seq_len, dim_head) interleaves head and position indices.
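For illustration, a minimal sketch of the reshape this comment appears to ask for; the shapes below are made up for the example and are not taken from the repo's config:

    import torch

    # Hypothetical sizes, for demonstration only.
    batch_size, seq_len, num_head, dim_head = 2, 5, 4, 8
    Q = torch.randn(batch_size, seq_len, num_head * dim_head)

    # Direct view, as in the current code: the resulting "sequence" axis
    # mixes head and position indices.
    Q_direct = Q.view(batch_size * num_head, -1, dim_head)

    # Split heads first, transpose, then merge heads into the batch axis.
    Q_split = (Q.view(batch_size, seq_len, num_head, dim_head)
                 .transpose(1, 2)  # (batch, head, seq, dim_head)
                 .reshape(batch_size * num_head, seq_len, dim_head))

    # The two layouts differ, which is the point of this report.
    print(torch.equal(Q_direct, Q_split))  # False in general

Keeping the 4D (batch, head, seq, dim_head) layout until the attention call, and only transposing back and merging heads at the end, is another common way to write this.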