AkariAsai / self-rag

This includes the original implementation of SELF-RAG: Learning to Retrieve, Generate and Critique through self-reflection by Akari Asai, Zeqiu Wu, Yizhong Wang, Avirup Sil, and Hannaneh Hajishirzi.
https://selfrag.github.io/
MIT License
1.63k stars, 141 forks

Reproducing the TriviaQA numbers #20

Closed: jiangllan closed this issue 7 months ago

jiangllan commented 7 months ago

Hi @AkariAsai, great work, and thank you very much for this repo!

I am trying to reproduce the baseline results on the TriviaQA task using Llama-2-7b downloaded from Hugging Face, but the results I get are quite different from those reported in the paper.

The command I used for the retrieval-augmented experiment:

python run_baseline_lm.py \
 --model_name PATH_TO_LLAMA2-7B-HF \
 --input_file eval_data/triviaqa_test_w_gs.jsonl \
 --max_new_tokens 100 \
 --metric accuracy \
 --task qa \
 --dtype half \
 --mode retrieval \
 --batch_size 5 \
 --prompt_name "prompt_no_input_retrieval"

The gap is so large that I'm wondering whether I got some details wrong.

Thanks in advance :)

jiangllan commented 7 months ago

Since the PROMPT_DICT in utils.py does not include the key prompt_no_input_retrieval, I implemented it following run_short_form.py, as follows:

PROMPT_DICT = {
    "prompt_input": (
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    ),
    "prompt_no_input": (
        "### Instruction:\n{instruction}\n\n### Response:\n"
    ),
    "prompt_no_input_retrieval": (
        "### Instruction:\n{instruction}\n\n### Response:\n[Retrieval]<paragraph>{paragraph}</paragraph>"
    )
}
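The missing key surfaces as a plain dictionary lookup failure on the value passed to --prompt_name. A minimal sketch, assuming the script does something like PROMPT_DICT[prompt_name].format(...); the helper get_prompt is hypothetical and not part of the repo. It fills the retrieval template above and fails with a readable message when a prompt name is absent.

```python
# Two entries copied from the dictionary above; the rest are omitted.
PROMPT_DICT = {
    "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
    "prompt_no_input_retrieval": (
        "### Instruction:\n{instruction}\n\n### Response:\n"
        "[Retrieval]<paragraph>{paragraph}</paragraph>"
    ),
}


def get_prompt(prompt_name: str, **fields: str) -> str:
    """Hypothetical helper: look up a template by name and fill its placeholders."""
    try:
        template = PROMPT_DICT[prompt_name]
    except KeyError:
        available = ", ".join(sorted(PROMPT_DICT))
        raise KeyError(
            f"unknown prompt_name {prompt_name!r}; available: {available}"
        ) from None
    # str.format also raises KeyError if a placeholder such as {paragraph} is not supplied.
    return template.format(**fields)


prompt = get_prompt(
    "prompt_no_input_retrieval",
    instruction="Who wrote Hamlet?",
    paragraph="Hamlet is a tragedy written by William Shakespeare.",
)
print(prompt)
```

In this layout the retrieved passage is appended after "### Response:", wrapped in [Retrieval]<paragraph>...</paragraph>, so it sits where the model's own output would begin.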
AkariAsai commented 7 months ago

Hi, I just wanted to double-check: did you use meta-llama/Llama-2-7b-hf or meta-llama/Llama-2-7b? Sorry about the missing prompt_no_input_retrieval; I mistakenly dropped it when I was refactoring the code base. For the baselines, we used different RAG prompts, as we found that some models perform much better when the paragraphs are placed before the instructions. I'll update the script. That said, the number seems too low, and I am not sure why this happens... I can rerun the evaluations this week.

jiangllan commented 7 months ago

Hi, thanks for your response!

I used meta-llama/Llama-2-7b (converted into the Hugging Face Transformers format using the conversion script).

By the way, I also experimented with the meta-llama/Llama-2-7b-chat model and set the prompts following the original Llama 2 paper:

"my_prompt_no_input": (
    "Answer these questions:\nQ: {instruction}\n\nA:\n"
),
"my_prompt_no_input_retrieval": (
    "Answer these questions:\nQ: {instruction}\nReferences: {paragraph}\nA:"
)

This approach yielded more reasonable results.
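A minimal sketch of how the two QA-style templates above could be selected and filled, assuming the retrieval variant is used only when a reference paragraph is available. The dictionary name MY_PROMPTS and the helper build_qa_prompt are hypothetical; the template strings are copied verbatim from the comment, and no chat-specific formatting is added.

```python
from typing import Optional

# Templates copied verbatim from the comment above.
MY_PROMPTS = {
    "my_prompt_no_input": "Answer these questions:\nQ: {instruction}\n\nA:\n",
    "my_prompt_no_input_retrieval": (
        "Answer these questions:\nQ: {instruction}\nReferences: {paragraph}\nA:"
    ),
}


def build_qa_prompt(question: str, paragraph: Optional[str] = None) -> str:
    """Hypothetical helper: pick the retrieval template only when a paragraph is given."""
    if paragraph is None:
        return MY_PROMPTS["my_prompt_no_input"].format(instruction=question)
    return MY_PROMPTS["my_prompt_no_input_retrieval"].format(
        instruction=question, paragraph=paragraph
    )


closed_book = build_qa_prompt("Who wrote Hamlet?")
with_refs = build_qa_prompt(
    "Who wrote Hamlet?", "Hamlet is a tragedy written by William Shakespeare."
)
print(with_refs)
```

Both variants end with "A:", which cues a short answer rather than the open-ended response that the "### Response:" format invites.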

AkariAsai commented 7 months ago

I updated the prompt dict here. I'm currently rerunning the evaluations on TriviaQA and will let you know once it's done and I have the final number.

PROMPT_DICT = {
    "prompt_input": (
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    ),
    "prompt_no_input": (
        "### Instruction:\n{instruction}\n\n### Response:\n"
    ),
    "prompt_no_input_retrieval": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Paragraph:\n{paragraph}\n\n### Instruction:\n{instruction}\n\n### Response:"
    ),
... 
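To make the "paragraph before instruction" layout concrete, here is the updated prompt_no_input_retrieval template filled with a made-up question and passage. Only this one entry of the truncated dictionary is reproduced; the example strings are illustrative, not taken from the TriviaQA data.

```python
# Copied from the updated PROMPT_DICT entry above.
PROMPT_NO_INPUT_RETRIEVAL = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Paragraph:\n{paragraph}\n\n### Instruction:\n{instruction}\n\n### Response:"
)

paragraph = "Hamlet is a tragedy written by William Shakespeare."
instruction = "Who wrote Hamlet?"
prompt = PROMPT_NO_INPUT_RETRIEVAL.format(paragraph=paragraph, instruction=instruction)

# The retrieved paragraph now precedes the instruction, and the prompt stops at
# "### Response:" so the model's answer is generated immediately after it.
print(prompt)
```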
jiangllan commented 7 months ago

Got it! I will use it to rerun my experiment on TQA today. Let's see if I can obtain the same final number as you did = )

AkariAsai commented 7 months ago

Oh, one more thing: you should use match as the evaluation metric for the open-domain QA task! accuracy is used for tasks where we only have a single answer label, while match checks whether any of the gold answer strings is included in the model generation, as specified in the paper.
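A minimal sketch of why the choice of metric matters for free-form generations. match follows the description above (any gold answer string contained in the generation). exact_accuracy is a hypothetical stand-in for a single-label check, not the repo's accuracy function, and the repo's real implementations may normalize text (for example, lowercasing) differently.

```python
from typing import Sequence


def match(generation: str, gold_answers: Sequence[str]) -> int:
    """1 if any gold answer string occurs verbatim in the generation, else 0."""
    return int(any(answer in generation for answer in gold_answers))


def exact_accuracy(generation: str, gold: str) -> int:
    """Hypothetical single-label check: the whole trimmed generation must equal the label."""
    return int(generation.strip() == gold.strip())


generation = "The play Hamlet was written by William Shakespeare."
gold_answers = ["William Shakespeare", "Shakespeare"]

print(match(generation, gold_answers))              # 1
print(exact_accuracy(generation, gold_answers[0]))  # 0
```

A correct but verbose answer scores 1 under match and 0 under a strict single-label comparison, which is one way an open-QA score can come out far lower than expected.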

I reran the evaluation with the most recent commit and got 43.4 on TriviaQA using the command below:

python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file eval_data/triviaqa_test.jsonl \
--max_new_tokens 100 \
--metric match \
--result_fp triviaqa_llama2_7b-with_retrieval_full_test.jsonl \
--task qa \
--mode retrieval --prompt_name "prompt_no_input_retrieval" \
--download_dir /gscratch/h2lab/akari/model_cache

Let me know if you still see some issue.

jiangllan commented 7 months ago

Hi! I successfully reproduced similar numbers for the TQA task. Thank you so much for your assistance! I'm now ready to close this issue.