StonyBrookNLP / musique

Repository for MuSiQue: Multi-hop Questions via Single-hop Question Composition, TACL 2022
Creative Commons Attribution 4.0 International

Questions about SA model #3

Closed canghongjian closed 10 months ago

canghongjian commented 11 months ago

Hi, great work @HarshTrivedi! I noticed that the structure of the SA (Select + Answer) model may be a little different from what the paper describes. After reading your code, I think the selector of the SA model only selects the top K most relevant passages, and these K passages are not the final retrieval result. You then use a classification head in your answerer to score the K passages and keep those whose scores are higher than 0.5. That means your answerer also contributes to the supporting-passage retrieval task. Do I understand this correctly? Additionally, are there any SA results for the full HotpotQA and full 2WikiMultihopQA datasets? Finally, when I evaluate your SA model on the 2WikiMultihopQA 20k dev set (using your serialization_dir__select_and_answer_model_selector_for_2wikimultihopqa_20k_dataset__predictions__2wikimultihopqa_dev_20k.jsonl), I get an Sp F1 of 97.35, not the 99.0 reported in the paper. Am I making a mistake somewhere?
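
In pseudocode, the flow I have in mind looks roughly like this (the names below are just mine, to illustrate my reading of the code, not actual classes in the repo):

# My reading of the SA support prediction, as a rough sketch (hypothetical names):
def predict_support(selector, answerer, paragraphs, k):
    topk = selector.rank(paragraphs)[:k]                  # selector keeps only the top-K passages
    scores = answerer.support_scores(topk)                # answerer's classification head scores them
    return [p for p, s in zip(topk, scores) if s > 0.5]   # keep passages scored above 0.5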

HarshTrivedi commented 11 months ago

Hi @canghongjian,

In the Select+Answer model, the selector selects the top-K paragraphs, which are then passed to the answerer, which predicts the final answer and the supporting paragraphs. The supporting paragraphs come from the classification head on top of the answerer, not from the answer confidence, but yes, answer supervision would influence support prediction. Note that the selector's top-K set is NOT the final support prediction; its only goal is to reduce the input the answerer has to deal with. This is consistent with what we wrote in the paper, Section 7.2.1: Select+Answer (SA) Model.

I have rechecked the SP F1 number for 2WikiMultihopQA and can confirm that the 99.0 SP F1 number is correct. Here is a script to reproduce this evaluation using our released predictions.
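
For reference, here is roughly what the two prediction files contain, based on the fields the script below reads (a simplified sketch, not the exact schema):

# One JSON object per line in each predictions file (simplified sketch):
# selector prediction instance:
#   {"id": ..., "contexts": [...],               # the original paragraphs
#    "predicted_ordered_contexts": [...]}        # the same paragraphs, re-ranked by the selector
# answerer prediction instance:
#   {"id": ..., "contexts": [...],                         # the selector's top `take_predicted_topk_contexts`, with "title || " prefixes
#    "predicted_select_support_indices": [...],            # support indices into its own "contexts"
#    "skipped_support_contexts": [...]}                    # supporting paragraphs the selector's top-K missed (used in Method 2)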

import os
import json
from typing import List, Dict
from metrics.support import SupportMetric

def read_jsonl(file_path: str) -> List[Dict]:
    with open(file_path, "r") as file:
        instances = [json.loads(line.strip()) for line in file if line.strip()]
    return instances

# python download_models.py select_and_answer_model_selector_for_2wikimultihopqa_20k_dataset
selector_predictions_file_path = os.path.join(
    "serialization_dir",
    "select_and_answer_model_selector_for_2wikimultihopqa_20k_dataset",
    "predictions",
    "2wikimultihopqa_dev_20k.jsonl"
)
# python download_models.py select_and_answer_model_answerer_for_2wikimultihopqa_20k_dataset
answerer_predictions_file_path = os.path.join(
    "serialization_dir",
    "select_and_answer_model_answerer_for_2wikimultihopqa_20k_dataset",
    "predictions",
    # notice: the next line is the output of the selector (string.replace("/", "__") of the selector_predictions_file_path)
    "serialization_dir__select_and_answer_model_selector_for_2wikimultihopqa_20k_dataset__predictions__2wikimultihopqa_dev_20k.jsonl"
)

selector_predictions_instances = read_jsonl(selector_predictions_file_path)
answerer_predictions_instances = read_jsonl(answerer_predictions_file_path)
support_metric = SupportMetric()

assert len(selector_predictions_instances) == len(answerer_predictions_instances)

take_predicted_topk_contexts = 5  # This is a hyperparameter, and it differs across experiments.
# For SA on 2WikiMultihopQA it is 5; see:
# https://github.com/StonyBrookNLP/musique/blob/main/experiment_configs/select_and_answer_model_answerer_for_2wikimultihopqa_20k_dataset.jsonnet
for selector_prediction_instance, answerer_prediction_instance in zip(
    selector_predictions_instances, answerer_predictions_instances
):
    # High-level explanation:
    # The selector ranks the paragraphs; the answerer picks the top `take_predicted_topk_contexts` of them
    # and generates the answer and supporting-paragraph predictions.
    # `take_predicted_topk_contexts` is a hyperparameter, which we varied in {3, 5, 7} (reconfirm exact HPs from the paper).
    # This is consistent with what is described in the paper, Section 7.2.1: Select+Answer (SA) Model.

    # Low-level Explanation:
    # The selector ranks the '.contexts' and stores them in the field '.predicted_ordered_contexts' in the prediction jsonl file.
    # The answerer now reads context from this file, but from '.predicted_ordered_contexts' field instead of '.contexts' field.
    # It picks `take_predicted_topk_contexts` of them: '.predicted_ordered_contexts[:take_predicted_topk_contexts]'
    # and generates both the final answer and supporting paragraphs.
    # The supporting paragraphs are returned in '.predicted_select_support_indices' field in the prediction file.
    # These indices are effectively indices into '.predicted_ordered_contexts' from the selector.
    # So we can obtain the ground-truth indices from '.predicted_ordered_contexts' (from selector) using its is_supporting field
    # and compare it with '.predicted_select_support_indices' (from answerer).
    # Note again that the selector only reorders '.contexts', so the 'is_supporting' field is not affected.

    # To confirm that selector's .predicted_ordered_contexts[:take_predicted_topk_contexts] is the same as
    # answerer's .contexts, see that this assertion passes:
    for context in answerer_prediction_instance["contexts"]:
        # The answerer prepends titles before ||, so let's remove them first.
        context["paragraph_text"] = context["paragraph_text"].split(" || ", 1)[1]
    assert (
        selector_prediction_instance["predicted_ordered_contexts"][:take_predicted_topk_contexts] ==
        answerer_prediction_instance["contexts"]
    )

    answerer_predicted_indices = answerer_prediction_instance["predicted_select_support_indices"]
    answerer_predicted_contexts = [answerer_prediction_instance["contexts"][index] for index in answerer_predicted_indices]

    ground_truth_support_indices = [
        # The '.predicted_ordered_contexts' is just a re-ordering of the original '.contexts'.
        index for index, context in enumerate(selector_prediction_instance["predicted_ordered_contexts"]) if context["is_supporting"]
    ]
    predicted_support_indices = [
        selector_prediction_instance["predicted_ordered_contexts"].index(answerer_prediction_instance["contexts"][index])
        for index in answerer_prediction_instance["predicted_select_support_indices"]
    ]
    support_metric(predicted_support_indices, ground_truth_support_indices)

sp_em, sp_f1 = support_metric.get_metric(reset=True)
print("Method 1:")
print(f"SP-EM: {round(sp_em*100, 1)}") # 97.0
print(f"SP-F1: {round(sp_f1*100, 1)}") # 99.0 (reported)

# There is also another way you can obtain the same results
for answerer_prediction_instance in answerer_predictions_instances:
    predicted_support_indices = answerer_prediction_instance["predicted_select_support_indices"]
    contexts = answerer_prediction_instance["contexts"] + answerer_prediction_instance["skipped_support_contexts"]
    ground_truth_support_indices = [index for index, context in enumerate(contexts) if context["is_supporting"]]
    support_metric(predicted_support_indices, ground_truth_support_indices)

sp_em, sp_f1 = support_metric.get_metric(reset=True)
print("Method 2:")
print(f"SP-EM: {round(sp_em*100, 1)}") # 97.0
print(f"SP-F1: {round(sp_f1*100, 1)}") # 99.0 (reported)

canghongjian commented 11 months ago

Got it. Given the additional factors introduced by the answerer, I report scores from the selector alone, trained on the full datasets; they are shown below:

[image: selector-only scores on the full datasets]

The calculation code is:

import json
def calculate_em_f1(predicted_support_idxs, gold_support_idxs):
    # Taken from hotpot_eval
    cur_sp_pred = set(map(int, predicted_support_idxs))
    gold_sp_pred = set(map(int, gold_support_idxs))
    tp, fp, fn = 0, 0, 0
    for e in cur_sp_pred:
        if e in gold_sp_pred:
            tp += 1
        else:
            fp += 1
    for e in gold_sp_pred:
        if e not in cur_sp_pred:
            fn += 1
    prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
    em = 1.0 if fp + fn == 0 else 0.0

    # In case everything is empty, set both f1, em to be 1.0.
    # Without this change, em gets 1 and f1 gets 0
    if not cur_sp_pred and not gold_sp_pred:
        f1, em = 1.0, 1.0
    return f1, em
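
# Quick sanity check with illustrative values (not taken from any dataset):
# calculate_em_f1([0, 2], [0, 1]) -> tp=1, fp=1, fn=1, so f1 = 0.5 and em = 0.0
# calculate_em_f1([0, 1], [1, 0]) -> exact set match,  so f1 = 1.0 and em = 1.0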

#url = '/root/musique/serialization_dir/select_and_answer_model_selector_for_hotpotqa_20k/predictions/hotpotqa_dev_20k.jsonl'
url = '/root/musique/serialization_dir/select_and_answer_model_selector_for_2wikimultihopqa_20k_dataset/predictions/2wikimultihopqa_dev_20k.jsonl'
hotpot_pred = open(url).readlines()
hotpot_pred = [json.loads(item) for item in hotpot_pred]
#hotpot_dev_data = json.load(open('/root/data/hotpotqa/hotpot_dev_distractor_v1.json'))
hotpot_dev_data = json.load(open('/root/data/wikimultihopqa/dev.json'))
sa_pred_hotpot_id2titles = {}

for item in hotpot_pred:
    # Map each question id to the selector's ranked list of paragraph titles.
    titles = [context['wikipedia_title'] for context in item['predicted_ordered_contexts']]
    sa_pred_hotpot_id2titles[item['id']] = titles

em_tot_hotpot, f1_tot_hotpot = [], []
for item in hotpot_dev_data:
    sf = []
    pred = []
    sp_title_set = []
    id = item['_id']
    # bridge_comparison questions involve 4 supporting paragraphs; the other question types involve 2.
    hop = 2
    if item['type'] == 'bridge_comparison':
        hop = 4
    # take the top [hop] passages as the predicted relevant passages
    sa_pred_hotpot_id2titles[id] = sa_pred_hotpot_id2titles[id][:hop]
    for sup in item['supporting_facts']:
        sp_title_set.append(sup[0])
    for idx, (title, sts) in enumerate(item['context']):
        if title in sa_pred_hotpot_id2titles[id]:
            pred.append(idx)
        if title in sp_title_set:
            sf.append(idx)
    f1, em = calculate_em_f1(pred, sf)
    em_tot_hotpot.append(em)
    f1_tot_hotpot.append(f1)
print("em:", sum(em_tot_hotpot) / len(em_tot_hotpot), "f1:", sum(f1_tot_hotpot) / len(f1_tot_hotpot))

Thanks again for the clarification, @HarshTrivedi!