allanj / pytorch_neural_crf

Pytorch implementation of LSTM/BERT-CRF for named entity recognition
359 stars, 62 forks

About orig_to_token_index padding problem #31

Closed: EternalEep closed this issue 3 years ago

EternalEep commented 3 years ago

I have a question about your dataset preprocessing code.

The model/transformers_embedder.py file contains the TransformersEmbedder class. Its forward function returns the following (marked as code [1]):

return torch.gather(word_rep[:, :, :], 1, orig_to_token_index.unsqueeze(-1).expand(batch_size, max_sent_len, rep_size))

As we know, this is a gather operation that uses orig_to_token_index as the index. But within a batch, different sentences have different lengths. In your batch preprocessing code in data/transformers_dataset.py (marked as code [2]), I observed:

orig_to_tok_index = feature.orig_to_tok_index + [0] * padding_word_len

label_ids = feature.label_ids + [0] * padding_word_len

You use [0] * padding_word_len to pad both orig_to_tok_index and label_ids. So when code [1] runs, every padded position gets index 0, which means it receives the [CLS] embedding vector from BERT. The model is then asked to predict label index 0 (the PAD label in your code) from that [CLS] embedding.
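To make this concrete, here is a minimal pure-Python sketch of what `torch.gather(word_rep, 1, index.unsqueeze(-1).expand(batch_size, max_sent_len, rep_size))` computes: because the index is expanded across the hidden dimension, each word position `i` of sentence `b` copies the whole subword vector `word_rep[b][index[b][i]]`. The helper name `gather_word_reps`, the 2-dimensional toy vectors, and the token layout are hypothetical; only the zero-padding scheme comes from code [2].

```python
from typing import List

Vector = List[float]


def gather_word_reps(word_rep: List[List[Vector]],
                     orig_to_tok_index: List[List[int]]) -> List[List[Vector]]:
    """Hypothetical pure-Python equivalent of code [1]:
    out[b][i] = word_rep[b][orig_to_tok_index[b][i]] (the full vector is copied)."""
    return [
        [list(sent_rep[tok_idx]) for tok_idx in sent_index]
        for sent_rep, sent_index in zip(word_rep, orig_to_tok_index)
    ]


# One sentence of 6 subword tokens: [CLS] w1 w2 w3 [SEP] [PAD]
# Toy 2-d "embeddings"; the [CLS] vector sits at subword index 0.
word_rep = [[[-1.0, -1.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [9.0, 9.0], [8.0, 8.0]]]

# 3 real words mapped to subword positions 1..3, zero-padded to max_sent_len = 5 (as in code [2])
orig_to_tok_index = [[1, 2, 3] + [0] * 2]

word_level = gather_word_reps(word_rep, orig_to_tok_index)
print(word_level[0])
# [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [-1.0, -1.0], [-1.0, -1.0]]
```

The two padded word positions both receive the [CLS] vector, which is exactly the behaviour the question describes.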

I think it's a little weird to pad orig_to_token_index with 0 and then predict the PAD label from [CLS]. Do we need to change it, or have I just misunderstood the logic?

I hope to hear your explanation. Thank you very much!

allanj commented 3 years ago

For code [2], it makes no practical difference if we use some other index instead of 0.

This is because the CRF only takes the subsequence of the representation up to each sentence's actual length, instead of the full padded representation.

Specifically, in this step (https://github.com/allanj/pytorch_neural_crf/blob/master/src/model/module/linear_crf_inferencer.py#L102), if word_seq_len is only 8 while the max length is 10, we only take the scores up to position 8. The rest are not considered.
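The idea of only taking scores up to `word_seq_len` can be illustrated with a small pure-Python linear-chain CRF. This is a sketch under stated assumptions, not the repository's `linear_crf_inferencer.py`: the helpers `log_partition`, `gold_score` and `neg_log_likelihood`, the emission/transition layout, and the omission of start/stop transitions are hypothetical simplifications. Both terms of the negative log-likelihood read only positions `0 .. word_seq_len - 1`, so whatever fills the padded positions ([CLS] emissions with PAD labels, or anything else) cannot change the loss.

```python
import math
from typing import List

Matrix = List[List[float]]


def logsumexp(values: List[float]) -> float:
    # Numerically stable log(sum(exp(v)))
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_partition(emissions: Matrix, transitions: Matrix, word_seq_len: int) -> float:
    """Forward algorithm over the first word_seq_len positions only (hypothetical helper).
    emissions[i][y]: score of label y at word i (rows may be padded beyond word_seq_len).
    transitions[p][y]: score of moving from label p to label y."""
    num_labels = len(transitions)
    alpha = list(emissions[0])
    for i in range(1, word_seq_len):  # padded positions >= word_seq_len are never read
        alpha = [
            logsumexp([alpha[p] + transitions[p][y] for p in range(num_labels)]) + emissions[i][y]
            for y in range(num_labels)
        ]
    return logsumexp(alpha)


def gold_score(emissions: Matrix, transitions: Matrix,
               labels: List[int], word_seq_len: int) -> float:
    """Score of the gold label path, also truncated at word_seq_len."""
    score = emissions[0][labels[0]]
    for i in range(1, word_seq_len):
        score += transitions[labels[i - 1]][labels[i]] + emissions[i][labels[i]]
    return score


def neg_log_likelihood(emissions: Matrix, transitions: Matrix,
                       labels: List[int], word_seq_len: int) -> float:
    return (log_partition(emissions, transitions, word_seq_len)
            - gold_score(emissions, transitions, labels, word_seq_len))


# 3 toy labels (0 = PAD, as in code [2]); a sentence of 3 real words padded to max length 5
transitions = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
real = [[0.0, 1.0, 2.0], [0.5, 0.1, 0.3], [1.5, 0.2, 0.9]]
labels = [2, 1, 1]

pad_cls = real + [[5.0, 5.0, 5.0]] * 2      # padded rows: a "[CLS]-like" emission
pad_other = real + [[-3.0, 7.0, 0.0]] * 2   # padded rows: emissions from some other index

loss_a = neg_log_likelihood(pad_cls, transitions, labels + [0, 0], 3)    # PAD labels
loss_b = neg_log_likelihood(pad_other, transitions, labels + [2, 2], 3)  # arbitrary labels
print(loss_a == loss_b)  # True: padding content never enters the loss
```

Truncating by length is one way to realize the masking; an equivalent alternative is to run the recursion over the full padded length and keep `alpha` unchanged wherever a position is masked out.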

EternalEep commented 3 years ago

I understand what you mean now: the CRF layer effectively masks out the padded positions. Thank you for the explanation. It seems I should read the CRF code more carefully. Thank you, Allan.