liuwei1206 / LEBERT

Code for the ACL2021 paper "Lexicon Enhanced Chinese Sequence Labelling Using BERT Adapter"

I don't understand this line of code #35

Closed lvjiujin closed 2 years ago

lvjiujin commented 2 years ago

You implement the bilinear attention in the class BertLayer with the following code:

    if self.has_word_attn:
        assert input_word_mask is not None

        # transform
        # paper's (W_1 * X + b_1), X: (N, L, W, D), W_1: (D, d_c) => word_outputs: (N, L, W, H)
        word_outputs = self.word_transform(input_word_embeddings)  # [N, L, W, H]
        # paper's tanh(W_1 * X + b_1) => word_outputs: (N, L, W, H)
        word_outputs = self.act(word_outputs)
        # paper's W_2 * tanh(W_1 * X + b_1) + b_2, W_2: (d_c, d_c) => word_outputs: (N, L, W, H)
        word_outputs = self.word_word_weight(word_outputs)
        word_outputs = self.dropout(word_outputs)

        # attention_output = attention_output.unsqueeze(2)  # [N, L, D] -> [N, L, 1, D]
        # self.attn_W (W_attn in the paper): the weight matrix of the bilinear attention, shape (d_c, d_c)
        # layer_output.unsqueeze(2) -> (batch_size, seq_length, 1, hidden_size) = (N, L, 1, H)
        # alpha -> (batch_size, seq_length, 1, hidden_size) = (N, L, 1, H)
        alpha = torch.matmul(layer_output.unsqueeze(2), self.attn_W)  # [N, L, 1, H]
        # word_outputs: (N, L, W, H), transpose(word_outputs, 2, 3) -> (N, L, H, W)
        alpha = torch.matmul(alpha, torch.transpose(word_outputs, 2, 3))  # [N, L, 1, W], bilinear transform ends here
        alpha = alpha.squeeze()  # [N, L, W]
        alpha = alpha + (1 - input_word_mask.float()) * (-10000.0)
        alpha = torch.nn.Softmax(dim=-1)(alpha)  # [N, L, W]
        alpha = alpha.unsqueeze(-1)  # [N, L, W, 1]
        weighted_word_embedding = torch.sum(word_outputs * alpha, dim=2)  # [N, L, H]
        layer_output = layer_output + weighted_word_embedding

        layer_output = self.dropout(layer_output)
        layer_output = self.fuse_layer_norm(layer_output)

I don't understand the line `alpha = alpha + (1 - input_word_mask.float()) * (-10000.0)`. Can you give me an explanation?
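
For reference, here is a minimal standalone sketch of the shape flow in the snippet above, run with made-up sizes (N, L, W, H below are hypothetical, not values from the repo):

    import torch

    # Hypothetical sizes: batch N=2, sequence length L=5, matched words per
    # character W=3, BERT hidden size H=768 (d_c in the paper).
    N, L, W, H = 2, 5, 3, 768

    layer_output = torch.randn(N, L, H)      # character-level BERT hidden states
    word_outputs = torch.randn(N, L, W, H)   # word features after the two linear layers + tanh
    attn_W = torch.randn(H, H)               # stand-in for the bilinear weight attn_W

    alpha = torch.matmul(layer_output.unsqueeze(2), attn_W)    # [N, L, 1, H]
    alpha = torch.matmul(alpha, word_outputs.transpose(2, 3))  # [N, L, 1, W]
    print(alpha.shape)  # torch.Size([2, 5, 1, 3])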

liuwei1206 commented 2 years ago

Hi,

Every character matches a different number of words, up to the value of max_word_num. We use input_word_mask to denote which matched words are valid. That line makes sure only valid words contribute to the model. Hope it is helpful.
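
To make this concrete, here is a minimal sketch with made-up scores: adding -10000 to the padded word slots before the softmax drives their attention weights to (almost) zero, so only real matched words contribute to the weighted sum.

    import torch

    # Suppose a character matched only 2 real words; the 3rd slot is padding.
    scores = torch.tensor([1.2, 0.4, 0.9])     # hypothetical raw bilinear scores, shape [W]
    word_mask = torch.tensor([1.0, 1.0, 0.0])  # input_word_mask: 1 = valid word, 0 = padded slot

    scores = scores + (1.0 - word_mask) * (-10000.0)  # padded slot becomes roughly -10000
    weights = torch.softmax(scores, dim=-1)
    print(weights)  # approximately [0.69, 0.31, 0.00]: the padded slot gets ~zero weight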

lvjiujin commented 2 years ago

> Hi,
>
> Every character matches a different number of words, up to the value of max_word_num. We use input_word_mask to denote which matched words are valid. That line makes sure only valid words contribute to the model. Hope it is helpful.

Thank you, I see. After the bilinear transform, word_outputs becomes alpha, which blurs the distinction between matched and unmatched (padded) words, so this line is used to suppress the attention given to the unmatched words.