AkariAsai / self-rag

This includes the original implementation of SELF-RAG: Learning to Retrieve, Generate and Critique through self-reflection by Akari Asai, Zeqiu Wu, Yizhong Wang, Avirup Sil, and Hannaneh Hajishirzi.
https://selfrag.github.io/
MIT License

accuracy metric #72

Open zhuzihan728 opened 5 months ago

zhuzihan728 commented 5 months ago

The accuracy in metrics.py is defined as

def accuracy(preds, labels):
    match_count = 0
    for pred, label in zip(preds, labels):
        target = label[0]
        if pred == target:
            match_count += 1

    return 100 * (match_count / len(preds))

However, in run_short_form.py, accuracy is computed per data instance:

if args.metric == "accuracy":
    metric_result = accuracy(pred, row["output"])

where pred is a single string, and row["output"] is neither present in any of the short-form datasets nor defined anywhere in the code.
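
For what it's worth, here is a minimal sketch of the mismatch (assuming the accuracy function quoted above is importable as metrics.accuracy): it looks like it was written for a corpus-level call with a list of predictions and a list of gold-answer lists, whereas the per-instance call hands it a single string, so zip iterates over the characters of pred:

    from metrics import accuracy  # the function quoted above

    # Corpus-level call the function appears to be written for:
    preds = ["true", "false"]          # one prediction per instance
    labels = [["true"], ["false"]]     # one gold-answer list per instance
    print(accuracy(preds, labels))     # 100.0

    # Per-instance call as in run_short_form.py:
    pred = "true"                      # a single string
    # zip(pred, row["output"]) then pairs the character 't' with the
    # first gold answer, so the loop compares characters, not answers.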

leeds1219 commented 4 months ago

changed "output" to "answers" and it kind of fixed the problem but 0it [00:00, ?it/s]

...

Processed prompts: 0%| | 0/10 [00:00<?, ?it/s]

Processed prompts: 10%|█ | 1/10 [00:00<00:03, 2.59it/s]

Processed prompts: 70%|███████ | 7/10 [00:00<00:00, 13.11it/s] Processed prompts: 100%|██████████| 10/10 [00:00<00:00, 15.28it/s]

19it [00:53, 2.81s/it] Traceback (most recent call last): File "run_short_form.py", line 378, in main() File "run_short_form.py", line 349, in main metric_result = accuracy(pred, row["answers"]) File "/workspace/rag/rag/self-rag/retrieval_lm/metrics.py", line 21, in accuracy target = label[0] IndexError: string index out of range

I think the code isn't complete yet...
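
The traceback reproduces with just the accuracy function quoted above whenever the gold list contains an empty string, because target = label[0] indexes into it:

    from metrics import accuracy  # the function quoted above

    pred = "true"
    answers = ["true", ""]   # a gold list containing an empty string
    # zip("true", ["true", ""]) yields ('t', "true") and ('r', "").
    # On the second pair, target = ""[0] fails:
    accuracy(pred, answers)  # IndexError: string index out of range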

leeds1219 commented 4 months ago
preds = []
prompts = []
golds = []
metric_results = []
scores = []
all_results = []
count = 0
for i, row in tqdm(enumerate(input_data)):
    results = {}
    prompt = PROMPT_DICT["prompt_no_input"].format_map(row)
    _, evidences = process_data_evidences(row, top_n=args.ndocs)
    pred, results, do_retrieve = generate(
        prompt, evidences, max_new_tokens=args.max_new_tokens,)
    if type(pred) is str and pred[0] == "#" or pred[0] == ":":
        pred = pred[1:]
    prompts.append(prompt)
    preds.append(pred)
    all_results.append(results)
    if do_retrieve is True:
        count += 1
    if "answers" not in row and "answer" in row:
        row["answers"] = [row["answer"]] if type(
            row["answer"]) is str else row["answer"]
    ######################################################################
    # 2024-05-22: fixed index-out-of-range error
    row["answers"] = [answer for answer in row["answers"] if answer != ""]
    ######################################################################
    if args.metric == "accuracy":
        ##################################################################
        # 2024-05-22: fixed KeyError ("output" doesn't exist)
        # metric_result = accuracy(pred, row["output"])
        metric_result = accuracy(pred, row["answers"])
        ##################################################################
    elif args.metric == "match":
        if "SUPPORTS" in pred:
            pred = "true"
        elif "REFUTES" in pred:
            pred = "false"
        metric_result = match(pred, row["answers"])
    else:
        raise NotImplementedError

I modified the code as above and it works now.
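
As a side note, the two fixes could be folded into one small helper so the loop body stays unchanged; normalize_answers is just a hypothetical name, not something that exists in the repo:

    def normalize_answers(row):
        # Hypothetical helper (not in the repo): expose the gold answers
        # under a single "answers" key and drop empty strings so that
        # metrics.accuracy never indexes into "".
        if "answers" not in row and "answer" in row:
            row["answers"] = [row["answer"]] if isinstance(row["answer"], str) else row["answer"]
        row["answers"] = [a for a in row.get("answers", []) if a != ""]
        return row

    # in the loop, before the metric is computed:
    # row = normalize_answers(row)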

zhuzihan728 commented 4 months ago

Thanks for the reply :D I see now that the accuracy calculation checks whether the first character of the prediction matches the first character of a gold answer, because a string gets zipped with a list. That only makes sense if the gold answer list has length one, and presumably only for multiple-choice datasets?
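
Tracing one per-instance call of the accuracy function makes this concrete (the numbers just follow the code quoted at the top, not an actual run):

    from metrics import accuracy  # the function quoted at the top

    pred = "London"        # a single prediction string
    answers = ["London"]   # gold answer list of length one

    # zip("London", ["London"]) yields a single pair ('L', "London"),
    # so only the first characters are compared, and the denominator
    # is len("London") == 6 rather than the number of instances.
    print(accuracy(pred, answers))  # 100 * 1/6 ≈ 16.67, not 100.0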

And regarding the index-out-of-range error you pointed out: the authors seem to have mistakenly put empty strings ("") in some of the gold answer lists in the eval_data they provide in this link. Not sure why, but it definitely doesn't look right, and it only inflates the final metric score when the match method in metrics.py is used for the metric calculation. @AkariAsai
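
To make that last point concrete: if match does a substring-containment check over the gold list (my reading of metrics.py; treat the snippet below as an assumption, not the exact repo code), an empty gold string makes every prediction count as a hit, because "" is a substring of everything:

    # Assumed shape of metrics.match: 1 if any gold answer appears as a
    # substring of the prediction, else 0. A sketch, not the repo code.
    def match(prediction, ground_truth):
        return 1 if any(gt in prediction for gt in ground_truth) else 0

    print(match("no idea", ["Paris"]))      # 0 - genuinely wrong answer
    print(match("no idea", ["Paris", ""]))  # 1 - "" in "no idea" is True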