mickeysjm / R-BERT

Pytorch re-implementation of R-BERT model
GNU General Public License v3.0
66 stars, 15 forks

Question about processing entity output in model #4

Open liuyaduo opened 4 years ago

liuyaduo commented 4 years ago

Hi! In the original paper, I found that they apply an activation and then a fully connected layer after the averaging operation to get a vector representation for each of the two target entities. The code in this repo looks like this:

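def extract_entity(sequence_output, e_mask):
    # e_mask: (batch, seq_len) -> (batch, 1, seq_len)
    extended_e_mask = e_mask.unsqueeze(1)
    # (batch, 1, seq_len) x (batch, seq_len, hidden) -> (batch, hidden)
    extended_e_mask = torch.bmm(
        extended_e_mask.float(), sequence_output).squeeze(1)
    return extended_e_mask.float()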

e1_h = self.ent_dropout(extract_entity(sequence_output, e1_mask))
e2_h = self.ent_dropout(extract_entity(sequence_output, e2_mask))
context = self.cls_dropout(pooled_output)
pooled_output = torch.cat([context, e1_h, e2_h], dim=-1)

Why can't I find the activation and the fully connected layer in model.py?
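For readers following along, here is a minimal sketch of what the quoted `extract_entity` helper computes on a toy input. The tensors and the `avg_mask` normalization are assumptions made for illustration. This thread does not show how the repo builds `e1_mask` and `e2_mask`, so whether the result is a sum or an average depends on whether the mask is already normalized upstream.

```python
import torch

def extract_entity(sequence_output, e_mask):
    # (batch, seq_len) -> (batch, 1, seq_len)
    extended_e_mask = e_mask.unsqueeze(1)
    # (batch, 1, seq_len) x (batch, seq_len, hidden) -> (batch, hidden)
    return torch.bmm(extended_e_mask.float(), sequence_output).squeeze(1)

# Toy batch (illustrative values): 1 sequence, 4 tokens, hidden size 2.
sequence_output = torch.tensor(
    [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]])
binary_mask = torch.tensor([[0, 1, 1, 0]])  # entity covers tokens 1 and 2

# With a 0/1 mask the product is a SUM over the entity tokens.
summed = extract_entity(sequence_output, binary_mask)

# Dividing the mask by the entity length turns the same product into a MEAN.
avg_mask = binary_mask.float() / binary_mask.sum(dim=1, keepdim=True)
averaged = extract_entity(sequence_output, avg_mask)

print(summed)    # sum of token vectors 1 and 2
print(averaged)  # mean of token vectors 1 and 2
```

Either way, the helper stops at the pooled entity vector. No activation or linear layer is applied inside it, which matches the observation in the question.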

mickeysjm commented 4 years ago

Hi @liuyaduo,

Thanks for the detailed comparison. Indeed, this code does not include the additional activation function and fully connected layer. You can add them as follows:

def __init__(self, config):
        super(BertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.cls_dropout = nn.Dropout(0.1)  # dropout on CLS transformed token embedding
        self.ent_dropout = nn.Dropout(0.5)  # dropout on average entity embedding
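        self.ffn = nn.Linear(config.hidden_size, config.hidden_size)  # fully connected layer, shared by both entities
        self.activation = nn.Tanh()  # activation applied to the entity embedding before the ffn
        self.classifier = nn.Linear(config.hidden_size*3, self.config.num_labels)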

def forward(self, ......):
        e1_h = self.ent_dropout(extract_entity(sequence_output, e1_mask))
        e2_h = self.ent_dropout(extract_entity(sequence_output, e2_mask))
        # activation first, then the fully connected layer
        e1_h = self.ffn(self.activation(e1_h))
        e2_h = self.ffn(self.activation(e2_h))
        context = self.cls_dropout(pooled_output)
        pooled_output = torch.cat([context, e1_h, e2_h], dim=-1)

I am not sure whether this will improve performance, but it is easy to try.

Hope this answers your question.
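To try the suggestion without loading BERT, the added layers can be isolated into a small stand-alone module. This is only a sketch: the class name `EntityHead`, its constructor arguments, and the random inputs are hypothetical and not part of the repo. It assumes the entity vectors have already been pooled (for example by `extract_entity`) and that `pooled_output` is the BERT pooled [CLS] vector.

```python
import torch
import torch.nn as nn

class EntityHead(nn.Module):
    """Hypothetical head: dropout -> tanh -> linear on each entity vector,
    concatenated with the [CLS] context and fed to a classifier."""

    def __init__(self, hidden_size, num_labels,
                 cls_dropout=0.1, ent_dropout=0.5):
        super().__init__()
        self.cls_dropout = nn.Dropout(cls_dropout)
        self.ent_dropout = nn.Dropout(ent_dropout)
        self.activation = nn.Tanh()
        # One linear layer reused for both entities (shared weights).
        self.ffn = nn.Linear(hidden_size, hidden_size)
        self.classifier = nn.Linear(hidden_size * 3, num_labels)

    def forward(self, pooled_output, e1_vec, e2_vec):
        e1_h = self.ffn(self.activation(self.ent_dropout(e1_vec)))
        e2_h = self.ffn(self.activation(self.ent_dropout(e2_vec)))
        context = self.cls_dropout(pooled_output)
        features = torch.cat([context, e1_h, e2_h], dim=-1)
        return self.classifier(features)

# Smoke test with random inputs: batch of 2, hidden size 8, 3 labels.
head = EntityHead(hidden_size=8, num_labels=3)
head.eval()  # disable dropout so the forward pass is deterministic
logits = head(torch.randn(2, 8), torch.randn(2, 8), torch.randn(2, 8))
print(logits.shape)
```

Keeping the head separate like this makes it straightforward to compare runs with and without the extra `ffn` and `activation` layers while leaving the rest of the model unchanged.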