Official repo for xRAG: Extreme Context Compression for Retrieval-augmented Generation with One Token
Refer to the Dockerfile for required packages.
Configure wandb and accelerate:

```bash
wandb login
accelerate config
```
The following checkpoints are available on Hugging Face:

| Model | Backbone | Download |
|---|---|---|
| xRAG-7b | mistralai/Mistral-7B-Instruct-v0.2 | 🤗 Hugging Face |
| xRAG-MoE | mistralai/Mixtral-8x7B-Instruct-v0.1 | 🤗 Hugging Face |
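If you prefer to fetch a checkpoint ahead of time, `huggingface_hub` can download it to the local cache; the repo id below is the xRAG-7b one used in the evaluation commands further down.

```python
# Optional: pre-download the xRAG-7b checkpoint (repo id taken from the
# evaluation commands below); the returned path points at the local snapshot.
from huggingface_hub import snapshot_download

local_path = snapshot_download(repo_id="Hannibal046/xrag-7b")
print(local_path)
```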
We provide a tutorial for xRAG in `tutorial.ipynb`. Check it out!
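If you just want the gist before opening the notebook: xRAG encodes a retrieved document into a single dense embedding with a frozen retriever, projects it into the LLM's token-embedding space, and lets that one vector stand in for the whole document. The sketch below only illustrates that idea with placeholder components (a randomly initialized linear projector, simple last-token pooling); it is not the repo's actual modules or API.

```python
# Illustrative only: compress a document into ONE soft token and prepend it
# to the prompt embeddings. The projector here is a random placeholder, not
# the trained xRAG projector.
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

retriever = AutoModel.from_pretrained("Salesforce/SFR-Embedding-Mistral", torch_dtype=torch.bfloat16)
ret_tok = AutoTokenizer.from_pretrained("Salesforce/SFR-Embedding-Mistral")
llm = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", torch_dtype=torch.bfloat16)
llm_tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

document = "Paris is the capital and most populous city of France."
question = "What is the capital of France?"

# 1) One dense embedding for the whole document (last-token pooling).
inputs = ret_tok(document, return_tensors="pt")
with torch.no_grad():
    doc_emb = retriever(**inputs).last_hidden_state[0, -1]        # (d_retriever,)

# 2) A placeholder projector into the LLM embedding space; in xRAG, a small
#    trained projector plays this role.
projector = torch.nn.Linear(doc_emb.shape[-1], llm.config.hidden_size, dtype=torch.bfloat16)
soft_token = projector(doc_emb).unsqueeze(0).unsqueeze(0)         # (1, 1, d_llm)

# 3) Prepend the single soft token to the prompt embeddings and generate.
prompt_ids = llm_tok(question, return_tensors="pt").input_ids
prompt_embs = llm.get_input_embeddings()(prompt_ids)
inputs_embeds = torch.cat([soft_token, prompt_embs], dim=1)
out = llm.generate(inputs_embeds=inputs_embeds, max_new_tokens=20)
print(llm_tok.decode(out[0], skip_special_tokens=True))
```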
For data preparation, refer to `prepare_data.ipynb`.
Training scripts are in `scripts/`. For example, to train a Mistral-7b model with the SFR retriever:

```bash
accelerate launch \
    --mixed_precision bf16 \
    --num_machines 1 \
    --num_processes 8 \
    --main_process_port 29666 \
    -m src.language_modeling.train \
    --config config/language_modeling/pretrain.yaml
```
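The YAML config drives the run; a quick, generic way to inspect it before launching (plain YAML loading, nothing repo-specific):

```python
# Print the training config before launching (generic YAML inspection).
import yaml

with open("config/language_modeling/pretrain.yaml") as f:
    config = yaml.safe_load(f)
for key, value in config.items():
    print(f"{key}: {value}")
```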
The evaluation code is in `src/eval`. For example, to evaluate on TriviaQA:
Without retrieval augmentation:

```bash
CUDA_VISIBLE_DEVICES=0 python -m src.eval.run_eval \
    --data triviaqa \
    --model_name_or_path Hannibal046/xrag-7b
```

With retrieval augmentation:

```bash
CUDA_VISIBLE_DEVICES=0 python -m src.eval.run_eval \
    --data triviaqa \
    --model_name_or_path Hannibal046/xrag-7b \
    --use_rag
```

With xRAG:

```bash
CUDA_VISIBLE_DEVICES=0 python -m src.eval.run_eval \
    --data triviaqa \
    --model_name_or_path Hannibal046/xrag-7b \
    --retriever_name_or_path Salesforce/SFR-Embedding-Mistral \
    --use_rag
```
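The `run_eval` script reports the dataset metrics itself; if you want to score saved predictions offline, the snippet below is a generic, not repo-specific, exact-match-style check commonly used for open-domain QA:

```python
# Generic open-domain QA scoring (illustrative; not the metric code in src/eval).
import re
import string

def normalize(text: str) -> str:
    """Lowercase, strip punctuation/articles/extra whitespace (SQuAD-style)."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def has_answer(prediction: str, answers: list[str]) -> bool:
    """Count a prediction as correct if any gold alias appears in it."""
    pred = normalize(prediction)
    return any(normalize(ans) in pred for ans in answers)

preds = ["The capital of France is Paris."]
golds = [["Paris", "City of Paris"]]
acc = sum(has_answer(p, g) for p, g in zip(preds, golds)) / len(preds)
print(f"accuracy: {acc:.3f}")
```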
To benchmark xRAG, we provide the profiling code in `src/language_modeling/profiler.py`:
```bash
python -m src.language_modeling.profiler --instruction_length 54 --generation_length 30 --dataset triviaqa --use_xrag
python -m src.language_modeling.profiler --instruction_length 54 --generation_length 30 --dataset triviaqa
```
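For rough end-to-end timing of your own prompts outside the provided profiler, a generic PyTorch/transformers measurement looks roughly like the following; it times the plain Mistral backbone as a stand-in and is not how `profiler.py` measures things:

```python
# Generic GPU latency measurement for generation (illustrative only; use
# src/language_modeling/profiler.py for the reported benchmarks).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-Instruct-v0.2"  # backbone stand-in
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda"
)

inputs = tok("Question: who wrote Hamlet?\nAnswer:", return_tensors="pt").to("cuda")

model.generate(**inputs, max_new_tokens=30)          # warm-up
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
model.generate(**inputs, max_new_tokens=30)
end.record()
torch.cuda.synchronize()
print(f"latency: {start.elapsed_time(end):.1f} ms")
```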