ShannonAI / mrc-for-flat-nested-ner

Code for ACL 2020 paper `A Unified MRC Framework for Named Entity Recognition`

A modification of BertQueryNER forward function #81

Closed: smiles724 closed this issue 3 years ago.

smiles724 commented 3 years ago

Hi, I noticed that the span-matrix computation in the `BertQueryNER` forward function can be simplified.

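The original code is:

```python
# sequence_heatmap: [batch, seq_len, hidden]
# start_extend[b, i, j] == sequence_heatmap[b, i]; end_extend[b, i, j] == sequence_heatmap[b, j]
start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
span_matrix = torch.cat([start_extend, end_extend], 3)  # [batch, seq_len, seq_len, 2 * hidden]
```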
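To make the shapes concrete, here is a small standalone sketch that runs the original three lines on a random tensor. It assumes `sequence_heatmap` has shape `[batch, seq_len, hidden]` (as for a BERT encoder output); the sizes below are arbitrary and chosen only for illustration. It checks that entry `(i, j)` of `span_matrix` concatenates the vectors of token `i` and token `j`, and that the two expanded tensors are transposes of each other rather than identical copies.

```python
import torch

torch.manual_seed(0)
batch, seq_len, hidden = 2, 4, 3  # hypothetical sizes for illustration
sequence_heatmap = torch.randn(batch, seq_len, hidden)

# Original construction from the forward function.
start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)  # [B, S, S, H]
end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)    # [B, S, S, H]
span_matrix = torch.cat([start_extend, end_extend], 3)                     # [B, S, S, 2H]

# Entry (i, j) pairs token i (span start) with token j (span end).
i, j = 1, 3
expected = torch.cat([sequence_heatmap[:, i], sequence_heatmap[:, j]], dim=-1)
print(torch.equal(span_matrix[:, i, j], expected))  # True

# The two halves broadcast along different axes: one is the other transposed.
print(torch.equal(start_extend, end_extend.transpose(1, 2)))  # True
print(torch.equal(start_extend, end_extend))  # False for random input
```

Because `unsqueeze(2)` and `unsqueeze(1)` insert the new axis in different positions, the same source tensor ends up indexed by the start position in one half and by the end position in the other.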

However, since both `start_extend` and `end_extend` are derived from the same tensor, I think this code can be reduced to:

```python
span_matrix = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, 2)
```

Is my understanding correct?
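One way to check the proposal is to run both versions side by side. The sketch below assumes `sequence_heatmap` is `[batch, seq_len, hidden]` with `hidden > 1`; the sizes are arbitrary. `Tensor.expand` can only enlarge singleton dimensions, so the one-liner raises a `RuntimeError` when `hidden > 1`. Even when `hidden == 1` and the call succeeds, both output channels hold token `j`, so the start-token information is lost. The `torch.broadcast_tensors` variant at the end is a sketch of an equivalent formulation, not code from the repository.

```python
import torch

torch.manual_seed(0)
batch, seq_len, hidden = 2, 4, 3  # hypothetical sizes; any hidden > 1 behaves the same
sequence_heatmap = torch.randn(batch, seq_len, hidden)

# Original span matrix: [B, S, S, 2H].
start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
span_matrix = torch.cat([start_extend, end_extend], 3)

# Proposed one-liner: expand() cannot resize the non-singleton hidden dim to 2.
try:
    proposed = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, 2)
except RuntimeError as err:
    proposed = None
    print("expand failed:", err)

# With hidden == 1 the expand succeeds, but both channels hold token j only.
h1 = torch.randn(batch, seq_len, 1)
one_liner = h1.unsqueeze(1).expand(-1, seq_len, -1, 2)
reference = torch.cat([h1.unsqueeze(2).expand(-1, -1, seq_len, -1),
                       h1.unsqueeze(1).expand(-1, seq_len, -1, -1)], 3)
print(torch.equal(one_liner, reference))  # False

# Equivalent sketch: broadcast the start view and end view to a common shape.
start_view, end_view = torch.broadcast_tensors(
    sequence_heatmap.unsqueeze(2), sequence_heatmap.unsqueeze(1)
)
alt_span_matrix = torch.cat([start_view, end_view], dim=3)
print(torch.equal(alt_span_matrix, span_matrix))  # True
```

The key point is that the start half and the end half broadcast over different axes, so a single `expand` of one view cannot reproduce both.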