Closed: canghongjian closed this issue 10 months ago.
Hi @canghongjian,
In the Select+Answer model, the selector selects the topK paragraphs, which are then passed to the answerer, which predicts the final answer and the supporting paragraphs. The supporting paragraphs are produced by a classification head on top of the answerer, not by the answer confidence; that said, answer supervision would still influence support prediction. Note that the selector's selected topK paragraph set is NOT the final support prediction: its only goal is to reduce the input that the answerer has to deal with. This is consistent with what we wrote in the paper: 7.2.1 Select+Answer (SA) Model.
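For intuition, here is a minimal toy sketch of that division of labor. All names, scores, and probabilities below are illustrative stand-ins, not the actual models: the selector only ranks and truncates, and the final support prediction comes from the answerer's classification head (thresholded scores).

```python
# Toy sketch of the Select+Answer division of labor.
# All scoring values here are hypothetical stand-ins, not the real models.

def selector(paragraphs, scores, top_k):
    """Rank paragraphs by selector score and keep the top_k.

    This truncated set is NOT the final support prediction."""
    order = sorted(range(len(paragraphs)), key=lambda i: -scores[i])
    return [paragraphs[i] for i in order[:top_k]]

def answerer(selected_paragraphs, support_probs, threshold=0.5):
    """The answerer's classification head marks supporting paragraphs
    among the selected ones; the answer is generated jointly."""
    support = [p for p, prob in zip(selected_paragraphs, support_probs)
               if prob > threshold]
    answer = "..."  # generated by the answerer, not derived from support_probs
    return answer, support

paragraphs = ["p0", "p1", "p2", "p3"]
selector_scores = [0.9, 0.2, 0.8, 0.4]  # hypothetical selector scores
selected = selector(paragraphs, selector_scores, top_k=3)
# selected -> ['p0', 'p2', 'p3']
_, predicted_support = answerer(selected, support_probs=[0.95, 0.1, 0.7])
# predicted_support -> ['p0', 'p3']
```

The point of the sketch: the answerer can predict fewer (or different) supporting paragraphs than the selector passed along, so support metrics must be computed from the answerer's head, not from the selector's topK.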
I have rechecked the SP F1 number for 2WikiMultihopQA and can confirm that the 99.0 SP F1 number is correct. Here is a script to reproduce this evaluation using our released predictions.
```python
import os
import json
from typing import List, Dict

from metrics.support import SupportMetric


def read_jsonl(file_path: str) -> List[Dict]:
    with open(file_path, "r") as file:
        instances = [json.loads(line.strip()) for line in file if line.strip()]
    return instances


# python download_models.py select_and_answer_model_selector_for_2wikimultihopqa_20k_dataset
selector_predictions_file_path = os.path.join(
    "serialization_dir",
    "select_and_answer_model_selector_for_2wikimultihopqa_20k_dataset",
    "predictions",
    "2wikimultihopqa_dev_20k.jsonl",
)

# python download_models.py select_and_answer_model_answerer_for_2wikimultihopqa_20k_dataset
answerer_predictions_file_path = os.path.join(
    "serialization_dir",
    "select_and_answer_model_answerer_for_2wikimultihopqa_20k_dataset",
    "predictions",
    # Note: the next file name is the output of the selector, i.e.
    # selector_predictions_file_path with "/" replaced by "__".
    "serialization_dir__select_and_answer_model_selector_for_2wikimultihopqa_20k_dataset__predictions__2wikimultihopqa_dev_20k.jsonl",
)

selector_predictions_instances = read_jsonl(selector_predictions_file_path)
answerer_predictions_instances = read_jsonl(answerer_predictions_file_path)

support_metric = SupportMetric()

assert len(selector_predictions_instances) == len(answerer_predictions_instances)

# This is a hyperparameter (HP) which differs across experiments. For SA on 2WikiMultihopQA
# it is 5; see:
# https://github.com/StonyBrookNLP/musique/blob/main/experiment_configs/select_and_answer_model_answerer_for_2wikimultihopqa_20k_dataset.jsonnet
take_predicted_topk_contexts = 5

for selector_prediction_instance, answerer_prediction_instance in zip(
    selector_predictions_instances, answerer_predictions_instances
):
    # High-level explanation:
    # The selector ranks the paragraphs; the answerer picks `take_predicted_topk_contexts`
    # of them and generates the answer and supporting-paragraph predictions.
    # This `take_predicted_topk_contexts` is a hyperparameter, which we varied in {3, 5, 7}
    # (reconfirm the exact HPs from the paper).
    # This is consistent with what is described in the paper: 7.2.1 Select+Answer (SA) Model.
    #
    # Low-level explanation:
    # The selector ranks '.contexts' and stores them in the '.predicted_ordered_contexts' field
    # of its prediction jsonl file. The answerer then reads its contexts from this file, from
    # the '.predicted_ordered_contexts' field instead of the '.contexts' field.
    # It picks `take_predicted_topk_contexts` of them:
    # '.predicted_ordered_contexts[:take_predicted_topk_contexts]'
    # and generates both the final answer and the supporting paragraphs.
    # The supporting paragraphs are returned in the '.predicted_select_support_indices' field
    # of the prediction file. These indices are effectively indices into the selector's
    # '.predicted_ordered_contexts'. So we can obtain the ground-truth indices from
    # '.predicted_ordered_contexts' (from the selector) using its 'is_supporting' field
    # and compare them with '.predicted_select_support_indices' (from the answerer).
    # Note again that the selector only reorders '.contexts', so the 'is_supporting' field
    # is not affected.

    # To confirm that the selector's '.predicted_ordered_contexts[:take_predicted_topk_contexts]'
    # is the same as the answerer's '.contexts', see that this assertion passes:
    for context in answerer_prediction_instance["contexts"]:
        # The answerer prepends titles before " || ", so remove them first.
        context["paragraph_text"] = context["paragraph_text"].split(" || ", 1)[1]
    assert (
        selector_prediction_instance["predicted_ordered_contexts"][:take_predicted_topk_contexts]
        == answerer_prediction_instance["contexts"]
    )

    answerer_predicted_indices = answerer_prediction_instance["predicted_select_support_indices"]
    answerer_predicted_contexts = [
        answerer_prediction_instance["contexts"][index] for index in answerer_predicted_indices
    ]

    ground_truth_support_indices = [
        # '.predicted_ordered_contexts' is just a reordering of the original '.contexts'.
        index
        for index, context in enumerate(selector_prediction_instance["predicted_ordered_contexts"])
        if context["is_supporting"]
    ]
    predicted_support_indices = [
        selector_prediction_instance["predicted_ordered_contexts"].index(
            answerer_prediction_instance["contexts"][index]
        )
        for index in answerer_prediction_instance["predicted_select_support_indices"]
    ]
    support_metric(predicted_support_indices, ground_truth_support_indices)

sp_em, sp_f1 = support_metric.get_metric(reset=True)
print("Method 1:")
print(f"SP-EM: {round(sp_em * 100, 1)}")  # 97.0
print(f"SP-F1: {round(sp_f1 * 100, 1)}")  # 99.0 (reported)

# There is also another way to obtain the same results:
for answerer_prediction_instance in answerer_predictions_instances:
    predicted_support_indices = answerer_prediction_instance["predicted_select_support_indices"]
    contexts = (
        answerer_prediction_instance["contexts"]
        + answerer_prediction_instance["skipped_support_contexts"]
    )
    ground_truth_support_indices = [
        index for index, context in enumerate(contexts) if context["is_supporting"]
    ]
    support_metric(predicted_support_indices, ground_truth_support_indices)

sp_em, sp_f1 = support_metric.get_metric(reset=True)
print("Method 2:")
print(f"SP-EM: {round(sp_em * 100, 1)}")  # 97.0
print(f"SP-F1: {round(sp_f1 * 100, 1)}")  # 99.0 (reported)
```
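As a tiny self-contained illustration of the index bookkeeping in Method 1 (the contexts, ordering, and answerer predictions below are made up, not real model outputs): the selector reorders the original contexts, the answerer indexes into the truncated list, and `list.index` maps those positions back into the selector's ordering for comparison against the `is_supporting` ground truth.

```python
# Hypothetical miniature of the index mapping above; not real predictions.
original_contexts = [
    {"text": "a", "is_supporting": True},
    {"text": "b", "is_supporting": False},
    {"text": "c", "is_supporting": True},
]

# Suppose the selector ranked them c, a, b:
predicted_ordered_contexts = [original_contexts[2], original_contexts[0], original_contexts[1]]
take_predicted_topk_contexts = 2
answerer_contexts = predicted_ordered_contexts[:take_predicted_topk_contexts]  # [c, a]

# Suppose the answerer's classification head marks both of its contexts as supporting:
predicted_select_support_indices = [0, 1]

# Ground truth, expressed as indices into the selector's ordering:
ground_truth_support_indices = [
    i for i, c in enumerate(predicted_ordered_contexts) if c["is_supporting"]
]  # [0, 1]: both c and a are supporting

# Map the answerer's indices back into the selector's ordering:
predicted_support_indices = [
    predicted_ordered_contexts.index(answerer_contexts[i])
    for i in predicted_select_support_indices
]  # [0, 1]
```

Here prediction and ground truth coincide, so this instance would contribute a perfect SP EM/F1.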
Got it. Considering the additional factors introduced by the answerer, I report the scores only from the selector trained on the full datasets, which are shown as follows:
The calculation code is:

```python
import json


def calculate_em_f1(predicted_support_idxs, gold_support_idxs):
    # Taken from hotpot_eval
    cur_sp_pred = set(map(int, predicted_support_idxs))
    gold_sp_pred = set(map(int, gold_support_idxs))
    tp, fp, fn = 0, 0, 0
    for e in cur_sp_pred:
        if e in gold_sp_pred:
            tp += 1
        else:
            fp += 1
    for e in gold_sp_pred:
        if e not in cur_sp_pred:
            fn += 1
    prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
    em = 1.0 if fp + fn == 0 else 0.0
    # In case everything is empty, set both f1 and em to 1.0.
    # Without this change, em gets 1 and f1 gets 0.
    if not cur_sp_pred and not gold_sp_pred:
        f1, em = 1.0, 1.0
    return f1, em


# url = '/root/musique/serialization_dir/select_and_answer_model_selector_for_hotpotqa_20k/predictions/hotpotqa_dev_20k.jsonl'
url = '/root/musique/serialization_dir/select_and_answer_model_selector_for_2wikimultihopqa_20k_dataset/predictions/2wikimultihopqa_dev_20k.jsonl'
hotpot_pred = open(url).readlines()
hotpot_pred = [json.loads(item) for item in hotpot_pred]

# hotpot_dev_data = json.load(open('/root/data/hotpotqa/hotpot_dev_distractor_v1.json'))
hotpot_dev_data = json.load(open('/root/data/wikimultihopqa/dev.json'))

sa_pred_hotpot_id2titles = {}
for item in hotpot_pred:
    titles = [context['wikipedia_title'] for context in item['predicted_ordered_contexts']]
    sa_pred_hotpot_id2titles[item['id']] = titles

em_tot_hotpot, f1_tot_hotpot = [], []
for item in hotpot_dev_data:
    sf = []
    pred = []
    sp_title_set = []
    id = item['_id']
    hop = 2
    if item['type'] == 'bridge_comparison':
        hop = 4
    # Take the top `hop` passages as the predicted relevant passages.
    sa_pred_hotpot_id2titles[id] = sa_pred_hotpot_id2titles[id][:hop]
    for sup in item['supporting_facts']:
        sp_title_set.append(sup[0])
    for idx, (title, sts) in enumerate(item['context']):
        if title in sa_pred_hotpot_id2titles[id]:
            pred.append(idx)
        if title in sp_title_set:
            sf.append(idx)
    f1, em = calculate_em_f1(pred, sf)
    em_tot_hotpot.append(em)
    f1_tot_hotpot.append(f1)

print("em:", sum(em_tot_hotpot) / len(em_tot_hotpot), "f1:", sum(f1_tot_hotpot) / len(f1_tot_hotpot))
```
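For reference, here is the arithmetic inside `calculate_em_f1` worked through on hand-made index sets (the indices are hypothetical, not taken from any real prediction file):

```python
# Worked example of the support EM/F1 arithmetic (hypothetical index sets).
predicted = {0, 2, 4}
gold = {0, 2, 3}

tp = len(predicted & gold)                 # 2 correctly predicted supports
fp = len(predicted - gold)                 # 1 spurious prediction
fn = len(gold - predicted)                 # 1 missed support
prec = tp / (tp + fp)                      # 2/3
recall = tp / (tp + fn)                    # 2/3
f1 = 2 * prec * recall / (prec + recall)   # 2/3
em = 1.0 if fp + fn == 0 else 0.0          # 0.0: the sets are not identical
```

EM is all-or-nothing (any false positive or false negative drops it to 0), while F1 gives partial credit, which is why SP F1 is typically higher than SP EM in the numbers above.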
Thanks for the clarification again @HarshTrivedi!
Hi, great work @HarshTrivedi! I found that the structure of the SA (Select + Answer) model may be a little different from what the paper describes. After reading your code, I think the selector of the SA model only selects the K most relevant passages, and these K passages are not the final retrieval result. You use a classification head on your answerer to obtain scores for the K passages and keep those whose scores are higher than 0.5. That means your answerer also contributes to the relevant-passage retrieval task. Do I understand correctly? Additionally, are there any SA results for the full HotpotQA and full 2WikiMultihopQA datasets? Also, when I evaluate your SA predictions on 2Wiki-20k (using your
`serialization_dir__select_and_answer_model_selector_for_2wikimultihopqa_20k_dataset__predictions__2wikimultihopqa_dev_20k.jsonl`
), the SP F1 metric is 97.35, not the 99.0 reported in the paper. Am I making a mistake somewhere?