JLUGQQ opened this issue 2 years ago
@JLUGQQ: Yes, this is a bug; see issue https://github.com/facebookresearch/BLINK/issues/95 for the solution.
Thank you. I have tried this solution before, but it didn't work. Maybe I should change my package versions according to requirements.txt.
@JLUGQQ: I am able to successfully run eval on both zeshel and non-zeshel datasets. Feel free to copy and paste your error message here, and I'd be glad to take a look.
Thank you very much for your help!
I could successfully run train_biencoder, but when I ran eval_biencoder I encountered this problem. I have changed the code according to issue #95.
05/06/2022 13:33:00 - INFO - Blink - Getting top 64 predictions.
0%| | 0/2500 [00:00<?, ?it/s]05/06/2022 13:33:00 - INFO - Blink - World size : 16
0%| | 0/2500 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/data/gavin/BLINK-main/blink/biencoder/eval_biencoder.py", line 337, in <module>
@JLUGQQ: Here's what I have in git diff. Let me know if this helps.
diff --git a/blink/biencoder/nn_prediction.py b/blink/biencoder/nn_prediction.py
index eab90a8..18e50cd 100644
--- a/blink/biencoder/nn_prediction.py
+++ b/blink/biencoder/nn_prediction.py
@@ -55,13 +55,20 @@ def get_topk_predictions(
oid = 0
for step, batch in enumerate(iter_):
batch = tuple(t.to(device) for t in batch)
- context_input, _, srcs, label_ids = batch
+ if is_zeshel:
+ context_input, _, srcs, label_ids = batch
+ else:
+ context_input, _, label_ids = batch
+ srcs = torch.tensor([0] * context_input.size(0), device=device)
+
src = srcs[0].item()
+ cand_encode_list[src] = cand_encode_list[src].to(device)
scores = reranker.score_candidate(
context_input,
None,
- cand_encs=cand_encode_list[src].to(device)
+ cand_encs=cand_encode_list[src]
)
+
values, indicies = scores.topk(top_k)
old_src = src
for i in range(context_input.size(0)):
@@ -93,7 +100,7 @@ def get_topk_predictions(
continue
# add examples in new_data
- cur_candidates = candidate_pool[src][inds]
+ cur_candidates = candidate_pool[srcs[i].item()][inds]
nn_context.append(context_input[i].cpu().tolist())
nn_candidates.append(cur_candidates.cpu().tolist())
nn_labels.append(pointer)
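For reference, the batch-unpacking change in the patch above can be sketched in isolation. The names `is_zeshel` and the four-tuple batch shape follow the diff; the tensors here are dummies for illustration only:

```python
import torch

def unpack_batch(batch, is_zeshel):
    # Zeshel batches carry a per-example world id (srcs); plain batches do not.
    if is_zeshel:
        context_input, _, srcs, label_ids = batch
    else:
        context_input, _, label_ids = batch
        # Fall back to a single world (id 0) so downstream src indexing still works.
        srcs = torch.zeros(context_input.size(0), dtype=torch.long)
    return context_input, srcs, label_ids

# Dummy 3-example batches in both shapes.
ctx = torch.randn(3, 8)
mask = torch.ones(3, 8)
labels = torch.tensor([1, 2, 3])
worlds = torch.tensor([4, 4, 4])

_, srcs_z, _ = unpack_batch((ctx, mask, worlds, labels), is_zeshel=True)
_, srcs_p, _ = unpack_batch((ctx, mask, labels), is_zeshel=False)
print(srcs_z.tolist(), srcs_p.tolist())  # [4, 4, 4] [0, 0, 0]
```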
Unfortunately, it still doesn't work. Thanks for your reply. I think I should take some time to debug and find the exact reason, and I will comment here if I solve the problem.
The KeyError can happen because the validation or test set tries to look up its encodings in the training-set encodings. For example, validation data with src value 9 attempts to find its encoding among the training encodings, whose src values run from 0 to 8; that is why there is a KeyError for value 9. Although there may be multiple ways to fix this, I recommend saving each split's encodings in separate files. The following shell script worked in my case:
python blink/biencoder/eval_biencoder.py --path_to_model models/zeshel/biencoder/pytorch_model.bin --data_path data/zeshel/blink_format --output_path models/zeshel --encode_batch_size 128 --eval_batch_size 1 --top_k 64 --save_topk_result --bert_model bert-large-uncased --mode train --zeshel True --data_parallel --cand_encode_path data/zeshel/cand_enc/cand_enc_train.pt --cand_pool_path data/zeshel/cand_pool/cand_pool_train.pt
python blink/biencoder/eval_biencoder.py --path_to_model models/zeshel/biencoder/pytorch_model.bin --data_path data/zeshel/blink_format --output_path models/zeshel --encode_batch_size 128 --eval_batch_size 1 --top_k 64 --save_topk_result --bert_model bert-large-uncased --mode valid --zeshel True --data_parallel --cand_encode_path data/zeshel/cand_enc/cand_enc_valid.pt --cand_pool_path data/zeshel/cand_pool/cand_pool_valid.pt
python blink/biencoder/eval_biencoder.py --path_to_model models/zeshel/biencoder/pytorch_model.bin --data_path data/zeshel/blink_format --output_path models/zeshel --encode_batch_size 128 --eval_batch_size 1 --top_k 64 --save_topk_result --bert_model bert-large-uncased --mode test --zeshel True --data_parallel --cand_encode_path data/zeshel/cand_enc/cand_enc_test.pt --cand_pool_path data/zeshel/cand_pool/cand_pool_test.pt
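The key mismatch described above can be reproduced in miniature. The dictionary below stands in for `cand_encode_list`; the world counts (train worlds 0-8, a validation world 9) are illustrative:

```python
import torch

# Encodings cached during the train pass: one entry per training world (0..8).
cand_encode_list = {src: torch.randn(10, 4) for src in range(9)}

src = 9  # world id of a validation example, absent from the training cache
try:
    enc = cand_encode_list[src]
except KeyError:
    print(f"KeyError: {src}")  # same failure mode as in the traceback

# Running eval_biencoder once per split, with split-specific
# --cand_encode_path / --cand_pool_path files, keeps the keys consistent.
```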
When I run eval_biencoder, I encounter this problem:

Traceback (most recent call last):
  File "/data/gavin/blink-el/blink/biencoder/eval_biencoder.py", line 336, in <module>
    main(new_params)
  File "/data/gavin/blink-el/blink/biencoder/eval_biencoder.py", line 289, in main
    save_results,
  File "/data/gavin/blink-el/blink/biencoder/nn_prediction.py", line 65, in get_topk_predictions
    cand_encs=cand_encode_list[src].to(device)
KeyError: 9