facebookresearch / BLINK

Entity Linker solution
MIT License

KeyError #115

Open JLUGQQ opened 2 years ago

JLUGQQ commented 2 years ago

When I run eval_biencoder, I encounter this problem:

Traceback (most recent call last):
  File "/data/gavin/blink-el/blink/biencoder/eval_biencoder.py", line 336, in <module>
    main(new_params)
  File "/data/gavin/blink-el/blink/biencoder/eval_biencoder.py", line 289, in main
    save_results,
  File "/data/gavin/blink-el/blink/biencoder/nn_prediction.py", line 65, in get_topk_predictions
    cand_encs=cand_encode_list[src].to(device)
KeyError: 9

abhinavkulkarni commented 2 years ago

@JLUGQQ: Yes, this is a bug. See issue https://github.com/facebookresearch/BLINK/issues/95 for the solution.

JLUGQQ commented 2 years ago

Thank you. I have tried this solution before, but it didn't work. Maybe I should change my package versions according to requirements.txt.

abhinavkulkarni commented 2 years ago

@JLUGQQ: I am able to run eval successfully on both zeshel and non-zeshel datasets. Feel free to paste your error message here; I'd be glad to take a look.

JLUGQQ commented 2 years ago

Thank you very much for your help! I can run train_biencoder successfully, but when I run eval_biencoder I encounter the problem below. I have changed the code according to issue #95.

05/06/2022 13:33:00 - INFO - Blink - Getting top 64 predictions.
  0%| | 0/2500 [00:00<?, ?it/s]
05/06/2022 13:33:00 - INFO - Blink - World size : 16
  0%| | 0/2500 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/data/gavin/BLINK-main/blink/biencoder/eval_biencoder.py", line 337, in <module>
    main(new_params)
  File "/data/gavin/BLINK-main/blink/biencoder/eval_biencoder.py", line 289, in main
    save_results,
  File "/data/gavin/BLINK-main/blink/biencoder/nn_prediction.py", line 65, in get_topk_predictions
    cand_encs=cand_encode_list[src].to(device)
KeyError: 12

abhinavkulkarni commented 2 years ago

@JLUGQQ: Here's what I have in git diff. Let me know if this helps.

diff --git a/blink/biencoder/nn_prediction.py b/blink/biencoder/nn_prediction.py
index eab90a8..18e50cd 100644
--- a/blink/biencoder/nn_prediction.py
+++ b/blink/biencoder/nn_prediction.py
@@ -55,13 +55,20 @@ def get_topk_predictions(
     oid = 0
     for step, batch in enumerate(iter_):
         batch = tuple(t.to(device) for t in batch)
-        context_input, _, srcs, label_ids = batch
+        if is_zeshel:
+            context_input, _, srcs, label_ids = batch
+        else:
+            context_input, _, label_ids = batch
+            srcs = torch.tensor([0] * context_input.size(0), device=device)
+
         src = srcs[0].item()
+        cand_encode_list[src] = cand_encode_list[src].to(device)
         scores = reranker.score_candidate(
             context_input, 
             None, 
-            cand_encs=cand_encode_list[src].to(device)
+            cand_encs=cand_encode_list[src]
         )
+
         values, indicies = scores.topk(top_k)
         old_src = src
         for i in range(context_input.size(0)):
@@ -93,7 +100,7 @@ def get_topk_predictions(
                 continue

             # add examples in new_data
-            cur_candidates = candidate_pool[src][inds]
+            cur_candidates = candidate_pool[srcs[i].item()][inds]
             nn_context.append(context_input[i].cpu().tolist())
             nn_candidates.append(cur_candidates.cpu().tolist())
             nn_labels.append(pointer)

JLUGQQ commented 2 years ago

Pity, it still doesn't work. Thanks for your reply. I will take some time to debug and find the exact cause, and I will comment here if I solve the problem.

yc-song commented 2 years ago

The KeyError can occur when a validation or test run looks up its candidate encodings in the encodings computed for the training set. For example, validation data with src value 9 crashes when it looks for its encodings among the training encodings, which only have src values 0 to 8; that is why the error reports key 9. There may be several ways to fix this, but I recommend saving the encodings for each split in separate files. The following shell commands worked in my case:

python blink/biencoder/eval_biencoder.py --path_to_model models/zeshel/biencoder/pytorch_model.bin --data_path data/zeshel/blink_format --output_path models/zeshel --encode_batch_size 128 --eval_batch_size 1 --top_k 64 --save_topk_result --bert_model bert-large-uncased --mode train --zeshel True --data_parallel --cand_encode_path data/zeshel/cand_enc/cand_enc_train.pt --cand_pool_path data/zeshel/cand_pool/cand_pool_train.pt

python blink/biencoder/eval_biencoder.py --path_to_model models/zeshel/biencoder/pytorch_model.bin --data_path data/zeshel/blink_format --output_path models/zeshel --encode_batch_size 128 --eval_batch_size 1 --top_k 64 --save_topk_result --bert_model bert-large-uncased --mode valid --zeshel True --data_parallel --cand_encode_path data/zeshel/cand_enc/cand_enc_valid.pt --cand_pool_path data/zeshel/cand_pool/cand_pool_valid.pt

python blink/biencoder/eval_biencoder.py --path_to_model models/zeshel/biencoder/pytorch_model.bin --data_path data/zeshel/blink_format --output_path models/zeshel --encode_batch_size 128 --eval_batch_size 1 --top_k 64 --save_topk_result --bert_model bert-large-uncased --mode test --zeshel True --data_parallel --cand_encode_path data/zeshel/cand_enc/cand_enc_test.pt --cand_pool_path data/zeshel/cand_pool/cand_pool_test.pt
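
To check whether a cached encoding file matches the split being evaluated before the run crashes, something like the sketch below may help. It assumes, as the traceback suggests, that cand_encode_list is a dict keyed by src id. The helper name missing_worlds, the toy encodings, and the validation src ids are made up for illustration and are not part of BLINK.

```python
# Hypothetical stand-ins: in BLINK, cand_encode_list maps a world (src) id
# to a tensor of candidate encodings; plain lists are used here instead.


def missing_worlds(cand_encode_list, srcs):
    """Return the sorted src ids that occur in the data but have no encodings."""
    return sorted(set(srcs) - set(cand_encode_list))


# Encodings cached from a train run cover src ids 0..8 only.
train_encodings = {src: [[0.0, 0.0]] for src in range(9)}

# Src ids seen in a validation split (values assumed for illustration).
valid_srcs = [9, 9, 10, 12]

missing = missing_worlds(train_encodings, valid_srcs)
if missing:
    # Indexing cand_encode_list[src] for any of these ids would raise KeyError.
    print(f"No candidate encodings for src ids {missing}; "
          "use a cand_encode_path that was built for this split.")
```

If the list is non-empty, the encodings file was most likely produced for a different split, which matches the explanation above.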