
SELF-RAG: Learning to Retrieve, Generate and Critique through Self-reflection #778

Open irthomasthomas opened 8 months ago

irthomasthomas commented 8 months ago

SELF-RAG: Learning to Retrieve, Generate and Critique through Self-reflection

This repository includes the original implementation of SELF-RAG: Learning to Retrieve, Generate and Critique through Self-Reflection (ICLR 2024, Oral, top 1%) by Akari Asai, Zeqiu Wu, Yizhong Wang, Avirup Sil, and Hannaneh Hajishirzi.


Self-RAG (Figure right) is a new framework to train an arbitrary LM to learn to retrieve, generate, and critique to enhance the factuality and quality of generations, without hurting the versatility of LLMs.

Unlike the widely adopted Retrieval-Augmented Generation (RAG; Figure left) approach, Self-RAG retrieves on demand given diverse queries (e.g., it can retrieve multiple times or skip retrieval entirely), and it critiques its own generation from multiple fine-grained aspects by predicting reflection tokens as an integral part of generation. We conduct a segment-wise beam search to select the output that maximizes utility under diverse preferences.
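The segment-wise beam search ranks candidate segments using the model's own critique. As a minimal sketch, assuming a segment's score is its generation log-probability plus a weighted sum of normalized reflection-token probabilities (one term per critique type), the scoring and selection for a single step could look like this. The group definitions, weights, and function names are hypothetical, the utility term is omitted for brevity, and a full beam search would keep the top candidates at every step rather than only the best one.

```python
from typing import Dict, List

# Hypothetical grouping of critique tokens: for each group, the token treated
# as desirable and every token the model may emit for that group.
CRITIQUE_GROUPS: Dict[str, tuple] = {
    "relevance": ("[Relevant]", ["[Relevant]", "[Irrelevant]"]),
    "support": (
        "[Fully supported]",
        ["[Fully supported]", "[Partially supported]", "[No support / Contradictory]"],
    ),
}


def critique_score(token_probs: Dict[str, float], weights: Dict[str, float]) -> float:
    """Weighted sum, over groups, of P(desirable token) normalized within its group."""
    total = 0.0
    for group, (desired, options) in CRITIQUE_GROUPS.items():
        mass = sum(token_probs.get(tok, 0.0) for tok in options)
        if mass > 0:
            total += weights.get(group, 0.0) * token_probs.get(desired, 0.0) / mass
    return total


def segment_score(logprob: float, token_probs: Dict[str, float], weights: Dict[str, float]) -> float:
    """Generation log-probability plus the weighted critique score (assumed form)."""
    return logprob + critique_score(token_probs, weights)


def pick_best(candidates: List[dict], weights: Dict[str, float]) -> dict:
    """Keep the candidate segment with the highest combined score."""
    return max(candidates, key=lambda c: segment_score(c["logprob"], c["token_probs"], weights))
```

With equal weights, a candidate with a slightly higher log-probability but weak relevance and support probabilities loses to a well-supported one, which is the intended effect of folding the critique into the ranking.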

If you find our code, data, models, or the paper useful, please cite the paper:

@inproceedings{asai2024selfrag,
author={Asai, Akari and Wu, Zeqiu and Wang, Yizhong and Sil, Avirup and Hajishirzi, Hannaneh},
title={Self-{RAG}: Learning to Retrieve, Generate, and Critique through Self-Reflection},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=hSyW5go0v8}
}

Updates

Contents

  1. Installation
  2. Quick Start
  3. Retriever setup
  4. Training
  5. Inference
  6. Baselines
  7. FAQ
  8. Contact

Installation

Install dependent Python libraries by running the command below.

pip install -r requirements.txt

Please use the latest version of vllm: older versions may not let you set skip_special_tokens via SamplingParams, which was added by (this PR).

You can also create a conda environment by running the command below.

conda env create -f environment.yml

Quick start

You can download Self-RAG from the HuggingFace Hub. For inference, we recommend vllm, as it significantly speeds up inference.

from vllm import LLM, SamplingParams

# download_dir is the authors' cache path; point it at a writable directory on your machine
model = LLM("selfrag/selfrag_llama2_7b", download_dir="/gscratch/h2lab/akari/model_cache", dtype="half")
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)

def format_prompt(input, paragraph=None):
  prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
  if paragraph is not None:
    prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
  return prompt

query_1 = "Leave odd one out: twitter, instagram, whatsapp."
query_2 = "Can you tell me the difference between llamas and alpacas?"
queries = [query_1, query_2]

# generate without inserting any retrieved paragraph; the model decides whether it needs retrieval
preds = model.generate([format_prompt(query) for query in queries], sampling_params)
for pred in preds:
  print("Model prediction: {0}".format(pred.outputs[0].text))

Output:

Model prediction: Twitter, Instagram, and WhatsApp are all social media platforms. [No Retrieval]WhatsApp is the odd one out because it is a messaging app, while Twitter and Instagram are primarily used for sharing photos and videos.[Utility:5]</s>
Model prediction: Sure![Retrieval]<paragraph><paragraph>

As you can see, Self-RAG answers the first query without retrieval because it does not require any. For the second query, in contrast, Self-RAG outputs a [Retrieval] token, as this question requires more fine-grained factual grounding.
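Since skip_special_tokens=False keeps the reflection tokens in the decoded text, downstream code can read the model's decisions directly from it. A minimal parsing sketch follows. The token list contains the tokens visible in the outputs in this README plus others assumed from the Self-RAG token set, and the helper names are hypothetical.

```python
import re
from typing import List, Optional

# Tokens seen in the outputs above, plus others assumed from the Self-RAG token set.
REFLECTION_TOKENS = [
    "[No Retrieval]", "[Retrieval]", "[Continue to Use Evidence]",
    "[Relevant]", "[Irrelevant]",
    "[Fully supported]", "[Partially supported]", "[No support / Contradictory]",
] + [f"[Utility:{i}]" for i in range(1, 6)]

# One alternation over the escaped token strings, longest first.
_TOKEN_RE = re.compile(
    "|".join(re.escape(t) for t in sorted(REFLECTION_TOKENS, key=len, reverse=True))
)


def reflection_tokens(text: str) -> List[str]:
    """Return the reflection tokens in the order they appear in the output."""
    return _TOKEN_RE.findall(text)


def wants_retrieval(text: str) -> bool:
    """True if the model emitted the [Retrieval] token."""
    return "[Retrieval]" in reflection_tokens(text)


def utility_score(text: str) -> Optional[int]:
    """Return the first [Utility:N] value, or None if the model emitted none."""
    for tok in reflection_tokens(text):
        if tok.startswith("[Utility:"):
            return int(tok[len("[Utility:"):-1])
    return None


def plain_answer(text: str) -> str:
    """Drop reflection tokens, paragraph markup, and the end-of-sequence marker."""
    text = _TOKEN_RE.sub("", text)
    for marker in ("<paragraph>", "</paragraph>", "</s>"):
        text = text.replace(marker, "")
    return text.strip()
```

Applied to the two outputs above, wants_retrieval is False for the first prediction and True for the second, and utility_score returns 5 for the first.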

For queries that require factual grounding, you can insert a paragraph. Self-RAG can retrieve and insert paragraphs at any point during generation, and it recognizes them as long as they are surrounded by the context markup special tokens <paragraph> and </paragraph>.

# for a query that needs factual grounding
prompt = format_prompt("Can you tell me the difference between llamas and alpacas?", "The alpaca (Lama pacos) is a species of South American camelid mammal. It is similar to, and often confused with, the llama. Alpacas are considerably smaller than llamas, and unlike llamas, they were not bred to be working animals, but were bred specifically for their fiber.")
preds = model.generate([prompt], sampling_params)
print([pred.outputs[0].text for pred in preds])
# ['[Relevant]Alpacas are considerably smaller than llamas, and unlike llamas, they were not bred to be working animals, but were bred specifically for their fiber.[Fully supported][Utility:5]</s>']

Self-RAG finds the relevant inserted document and generates answers that are fully supported by the evidence.
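Because the model signals retrieval with a [Retrieval] token, a caller can pause generation at that point, insert a paragraph, and resume. The hypothetical helper below builds the follow-up prompt from the original prompt, the partial output, and a retrieved paragraph, using the same [Retrieval]<paragraph>...</paragraph> layout as format_prompt. It is a sketch of the idea under those assumptions, not the repository's long-form inference code.

```python
from typing import Optional

RETRIEVAL_TOKEN = "[Retrieval]"


def continue_with_paragraph(prompt: str, partial_output: str, paragraph: str) -> Optional[str]:
    """Build a follow-up prompt after the model emits [Retrieval] mid-answer.

    Returns None if the partial output contains no [Retrieval] token.
    """
    idx = partial_output.find(RETRIEVAL_TOKEN)
    if idx == -1:
        return None  # the model did not ask for retrieval
    # Keep the generated text up to and including the first [Retrieval] token,
    # dropping anything produced after it (e.g. empty <paragraph> tags).
    prefix = partial_output[: idx + len(RETRIEVAL_TOKEN)]
    return prompt + prefix + "<paragraph>" + paragraph + "</paragraph>"
```

The returned string could then be passed back to model.generate with the same sampling_params so the model continues its answer with the evidence in context.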

Run your evaluation using the online retrieval model

You can also run retrieval on demand and use it with Self-RAG. Because running retrieval over the full English Wikipedia requires a large amount of RAM and multiple GPUs, we created a subset of Wikipedia containing only the intro paragraphs of articles, for demo purposes only.

First, please download the corpus and embeddings (9GB in total).

git clone git@github.com:AkariAsai/self-rag.git
cd self-rag/retrieval_lm
bash download_demo_corpus.sh

If the script does not work, you can download the data from Google Drive or the HF dataset. Then, you can run the script below under retrieval_lm. We tested the script on one RTX 6000 with 24GB of memory and 100GB of RAM (it should also run with much less RAM).

from passage_retrieval import Retriever

# Reuses model, sampling_params, and format_prompt from the Quick start example.
# query_3 is not defined in the original snippet; this example question is an assumption.
query_3 = "When does overfitting occur?"

retriever = Retriever({})
retriever.setup_retriever_demo("facebook/contriever-msmarco", "enwiki_2020_intro_only/enwiki_2020_dec_intro_only.jsonl", "enwiki_2020_intro_only/enwiki_dec_2020_contriever_intro/*",  n_docs=5, save_or_load_index=False)
retrieved_documents = retriever.search_document_demo(query_3, 5)
# one prompt per retrieved document
prompts = [format_prompt(query_3, doc["title"] + "\n" + doc["text"]) for doc in retrieved_documents]
preds = model.generate(prompts, sampling_params)
top_doc = retriever.search_document_demo(query_3, 1)[0]
print("Reference: {0}\nModel prediction: {1}".format(top_doc["title"] + "\n" + top_doc["text"], preds[0].outputs[0].text))

Output:

Reference: Overfitting
  In statistics, overfitting is "the production of an analysis that corresponds too closely or exactly to a particular set of data, and may therefore fail to fit additional data or predict future observations reliably". An overfitted model is a statistical model that contains more parameters than can be justified by the data. The essence of overfitting is to have unknowingly extracted some of the residual variation (i.e., the noise) as if that variation represented underlying model structure. Underfitting occurs when a statistical model cannot adequately capture the underlying structure of the data. An under-fitted model is a model where some parameters or terms that would appear in a correctly specified model are
Model prediction: [Relevant]Overfitting occurs when a model has too many parameters relative to the amount of data it has been trained on, leading it to memorize the training data too closely and perform poorly on new, unseen data.[Fully supported][Utility:5]</s>

The retriever fetches the necessary document, and Self-RAG generates a fully grounded output.
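The demo prints only preds[0], the prediction conditioned on the top-ranked document. When several documents are retrieved, the predictions can also be compared by their reflection tokens. The sketch below is a simple text-only heuristic (prefer [Relevant], then stronger support, then higher utility). It is not the paper's scoring, which uses reflection-token probabilities, and best_prediction is a hypothetical helper.

```python
import re
from typing import List, Tuple

# Higher is better; a prediction with no support token gets -1.
SUPPORT_RANK = {
    "[Fully supported]": 2,
    "[Partially supported]": 1,
    "[No support / Contradictory]": 0,
}


def heuristic_key(text: str) -> Tuple[int, int, int]:
    """Rank by relevance first, then support level, then utility (1-5, 0 if absent)."""
    relevant = 1 if "[Relevant]" in text else 0
    support = max((rank for tok, rank in SUPPORT_RANK.items() if tok in text), default=-1)
    match = re.search(r"\[Utility:(\d)\]", text)
    utility = int(match.group(1)) if match else 0
    return (relevant, support, utility)


def best_prediction(texts: List[str]) -> int:
    """Index of the best prediction; ties go to the earliest (higher-ranked) document."""
    return max(range(len(texts)), key=lambda i: heuristic_key(texts[i]))


# With the demo's variables:
# best = best_prediction([p.outputs[0].text for p in preds])
# retrieved_documents[best] is then the evidence behind the chosen answer.
```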

Note that this demo uses a smaller corpus and Self-RAG with the full inference algorithm. For a full evaluation, you need to either set up a retriever or download our retrieved results. Please follow the instructions in the Inference section.

Retriever Setup

By default, we use Contriever as our retrieval component.
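Contriever is a dense retriever: queries and passages are embedded by the same encoder, and passages are ranked by the inner product of their embeddings. The NumPy sketch below shows that ranking on toy vectors. It is a brute-force illustration under that assumption, not the repository's code, which works over precomputed embedding shards and may use an index for speed.

```python
import numpy as np


def top_k_inner_product(query_emb: np.ndarray, passage_embs: np.ndarray, k: int):
    """Brute-force top-k search: score every passage by its dot product with the query."""
    scores = passage_embs @ query_emb  # shape: (num_passages,)
    k = min(k, scores.shape[0])
    if k <= 0:
        return np.array([], dtype=int), np.array([], dtype=scores.dtype)
    top = np.argpartition(-scores, k - 1)[:k]  # the k best, in arbitrary order
    top = top[np.argsort(-scores[top])]  # order them from highest to lowest score
    return top, scores[top]
```

np.argpartition avoids fully sorting all passage scores, which matters when the corpus has millions of rows; only the k survivors are sorted.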

Download data

Download the preprocessed passage data used in DPR.

cd retrieval_lm
wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
gunzip psgs_w100.tsv.gz  # produces psgs_w100.tsv, the file passed to --passages below

Then, download the precomputed passage embeddings. We use Contriever-MSMARCO.

wget https://dl.fbaipublicfiles.com/contriever/embeddings/contriever-msmarco/wikipedia_embeddings.tar
tar -xf wikipedia_embeddings.tar

Run retriever

You can run passage retrieval by running the command below.

cd retrieval_lm
python passage_retrieval.py \
    --model_name_or_path facebook/contriever-msmarco --passages psgs_w100.tsv \
    --passages_embeddings "wikipedia_embeddings/*" \
    --data YOUR_INPUT_FILE  \
    --output_dir YOUR_OUTPUT_FILE \
    --n_docs 20

Your input file should be either a JSON or a JSONL file. Each instance must contain either a question or an instruction field, which is used as the query during retrieval.
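A minimal sketch for producing an input file in the JSONL variant of that format: one JSON object per line, each carrying a question or instruction field. The helper name write_retrieval_input is hypothetical, records are validated before anything is written, and any extra fields are written through unchanged.

```python
import json
from pathlib import Path
from typing import Dict, Iterable


def write_retrieval_input(records: Iterable[Dict[str, str]], path: str) -> int:
    """Write records as JSONL after checking each has a 'question' or 'instruction' field."""
    records = list(records)
    for i, rec in enumerate(records):
        if "question" not in rec and "instruction" not in rec:
            raise ValueError(f"record {i} has neither a 'question' nor an 'instruction' field")
    with Path(path).open("w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
    return len(records)
```

The resulting file path is what you would pass as YOUR_INPUT_FILE to --data.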

Generate embeddings for your own data

You can generate embeddings for your own data by running the following command. (The script is adapted from the Contriever repository.) Note that generating embeddings from a large-scale corpus (>10M docs) can take time, and we recommend running it on multiple GPUs.

cd retrieval_lm
mkdir -p log
for i in {0..3}; do
  export CUDA_VISIBLE_DEVICES=${i}
  python generate_passage_embeddings.py  --model_name_or_path facebook/contriever-msmarco \
  --output_dir YOUR_OUTPUT_DIR \
  --passages YOUR_PASSAGE_DATA --shard_id ${i}  --num_shards 4 > ./log/nohup.my_embeddings.${i} 2>&1 &
done
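Each of the four background processes launched by the loop embeds one slice of the corpus, selected by --shard_id and --num_shards. The sketch below shows one plausible contiguous split in which the last shard absorbs the remainder. This is an assumption modeled on the Contriever script, so check generate_passage_embeddings.py for the exact rule; shard_bounds is a hypothetical helper.

```python
from typing import Tuple


def shard_bounds(num_passages: int, shard_id: int, num_shards: int) -> Tuple[int, int]:
    """Half-open [start, end) passage range for one shard (assumed contiguous split)."""
    shard_size = num_passages // num_shards
    start = shard_id * shard_size
    # The last shard takes whatever is left over after integer division.
    end = num_passages if shard_id == num_shards - 1 else start + shard_size
    return start, end


for i in range(4):
    print(i, shard_bounds(10, i, 4))
```

Under this split the shards cover every passage exactly once, so the per-shard embedding files can simply be loaded together with a glob such as "YOUR_OUTPUT_DIR/*".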

Training

Self-RAG trains two models, the Critic and the Generator. Both expand the token vocabulary with reflection tokens and are trained with the standard next-token prediction objective.
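The generator's training examples contain retrieved paragraphs that the model should condition on rather than learn to reproduce. The Self-RAG paper describes masking the retrieved text out of the loss; the hypothetical helper below sketches that masking on raw token ids, using -100, the default ignore_index of PyTorch's CrossEntropyLoss. This sketch masks the <paragraph> and </paragraph> markers together with the content, which is a choice of the sketch; the repository's training script may treat the markers differently, and the token ids in the example are made up.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_paragraph_labels(token_ids: List[int], para_open_id: int, para_close_id: int) -> List[int]:
    """Copy token ids into labels, masking every token from <paragraph> through </paragraph>."""
    labels: List[int] = []
    inside = False
    for tok in token_ids:
        if tok == para_open_id:
            inside = True
        labels.append(IGNORE_INDEX if inside else tok)
        if tok == para_close_id:
            inside = False
    return labels


# Made-up ids: 1 = <paragraph>, 2 = </paragraph>
print(mask_paragraph_labels([5, 1, 7, 8, 2, 9], 1, 2))
```

Reflection tokens outside the paragraph keep their labels, so the generator still learns to emit them.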

Alternatively, you can download our training data consisting of 150K instances here.

Collect reflection tokens

We collect training data from GPT-4. The scripts to call GPT-4 for each special token type are available at data_creation/critic.

Alternatively, you can download our training data here.

Critic training

Once you have created or downloaded the training data, run the command below to fine-tune Llama2-7B as the critic.

cd data_creation
torchrun --nproc_per_node=2 \
  --master_port=2568 train_special_tokens.py \
  --model_name_or_path meta-llama/Llama-2-7b-hf \
  --data_path PATH_TO_TRAIN_DATA_FILE \
  --bf16  True \
  --output_dir PATH_TO_CRITIC_MODEL \
  --num_train_epochs 3  \
  --per_device_train_batch_size 1 --per_device_eval_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 300 \
  --save_total_limit 1 \
  --learning_rate 2e-5 \
  --weight_decay 0. \
  --warmup_ratio 0.01 \
  --lr_scheduler_type "cosine" \
  --logging_steps 10 \
  --fsdp "full_shard auto_wrap"
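Because the critic command runs one process per GPU, each optimizer step sees num_gpus x per-device batch size x gradient accumulation steps examples. A tiny worked example with the values from the command above; effective_batch_size is a hypothetical helper.

```python
def effective_batch_size(num_gpus: int, per_device_batch: int, grad_accum_steps: int) -> int:
    """Examples contributing to each optimizer step in data-parallel training."""
    return num_gpus * per_device_batch * grad_accum_steps


# --nproc_per_node=2, --per_device_train_batch_size 1, --gradient_accumulation_steps 8
print(effective_batch_size(2, 1, 8))  # 16
```

If you change the number of GPUs, adjusting --gradient_accumulation_steps in the opposite direction keeps the effective batch size at 16.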

Generator Data Creation

The code to create Generator training data is under generator_data_creation. See the instructions at README.md.

Alternatively, you can download our training data from the HuggingFace dataset or here.

Generator Training

For generator training, we use DeepSpeed to make training more efficient. After setting the training data path, you can start training by running the script below.

cd retrieval_lm
bash script_finetune_7b.sh

For 13B model training, use training_13b. We use 8 A100 GPUs with 40GB of memory for 7B model training and 4 A100 GPUs with 80GB of memory for 13B training. The 7B model should fit on 1-2 A100s, although training can be slow.

Inference

For the task evaluation conducted in the paper, please download the data here.

Each file already comes with retrieved documents, so if you don't want to run a retriever as part of inference, you can simply load the retrieved documents stored under contexts.
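If you use the precomputed retrieval results, prompts can be built directly from each record, one prompt per document as in the retrieval demo. In this sketch the key names question and ctxs are assumptions about the file layout, while title and text match the documents returned in the retrieval demo; adjust them to the downloaded files. prompts_from_record is a hypothetical helper, and format_prompt repeats the Quick start template so the block runs on its own.

```python
from typing import Dict, List, Optional


def format_prompt(instruction: str, paragraph: Optional[str] = None) -> str:
    # Same template as the Quick start example.
    prompt = "### Instruction:\n{0}\n\n### Response:\n".format(instruction)
    if paragraph is not None:
        prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
    return prompt


def prompts_from_record(record: Dict, query_key: str = "question",
                        ctx_key: str = "ctxs", top_k: int = 5) -> List[str]:
    """One prompt per precomputed document; falls back to a no-retrieval prompt."""
    query = record[query_key]
    docs = record.get(ctx_key, [])[:top_k]
    if not docs:
        return [format_prompt(query)]
    return [format_prompt(query, doc["title"] + "\n" + doc["text"]) for doc in docs]
```

The resulting list can be passed to model.generate exactly like the prompts list in the retrieval demo.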

Below, we describe Self-RAG and baselines.

Short-form (PubHealth, ARC-Challenge, TriviaQA, PopQA)

As we typically retrieve only once for short-form generation tasks, we provide an easy-to-run evaluation script that uses documents retrieved offline by Contriever. See the individual command below.

Question Answering


python run_short
