RUC-NLPIR / FlashRAG

⚡FlashRAG: A Python Toolkit for Efficient RAG Research
https://arxiv.org/abs/2405.13576

The seed used for reproducing the experiments #44

Closed WenliangZhoushan closed 4 days ago

WenliangZhoushan commented 5 days ago

May I ask whether the default seed you used for the experiments is the 2024 set under basic_config? Each of my experiments comes out roughly 3-4 points lower than the results in the paper.

ignorejjj commented 5 days ago

The seed is 2024. What is your exact experimental setup?

WenliangZhoushan commented 5 days ago

> The seed is 2024. What is your exact experimental setup?

config_dict = {
    'data_dir': '/home/v-wenlzheng/FlashRAG/data/FlashRAG_datasets',
    'test_sample_num': 1000,
    'index_path': '/home/v-wenlzheng/FlashRAG/data/FlashRAG_datasets/retrieval-corpus/e5_flat_inner.index',
    'corpus_path': '/home/v-wenlzheng/FlashRAG/data/FlashRAG_datasets/retrieval-corpus/wiki18_100w.jsonl',
    'framework': 'vllm',
    'model2path': {'e5': 'intfloat/e5-base-v2', 'llama3-8B-instruct': 'meta-llama/Meta-Llama-3-8B-Instruct'},
    'generator_model': 'llama3-8B-instruct',
    'retrieval_method': 'e5',
    'metrics': ['em'],
    'retrieval_topk': 5,
    'save_intermediate_data': True,
}

This is my config_dict for nq. I ran naive and standard and got 18.7 and 33.5 points respectively; nothing else was changed. If I run the experiment for multiple rounds with the seed unchanged, shouldn't the results all be identical? Thanks.
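
For context, a "seed" setting like this typically pins the host-side RNGs, roughly along the lines of the generic sketch below (not FlashRAG's actual code); GPU kernels and external inference engines can still introduce run-to-run variance:

import random

import numpy as np
import torch

def set_seed(seed: int = 2024):
    # Pin the host-side RNGs. This does not by itself make GPU
    # generation deterministic, and engines such as vLLM manage
    # their own sampling state in worker processes.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)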

WenliangZhoushan commented 5 days ago

I just ran the experiment again. With the seed at 2024, three rounds of naive on nq scored 18.7, 17.1, and 17.4. It seems the seed does not keep the results fully consistent.
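
The spread across rounds likely points at sampling during generation rather than retrieval, since a fixed global seed does not necessarily reach vLLM's sampling workers. With vLLM, repeatable outputs are usually obtained via greedy decoding; a standalone sketch (independent of FlashRAG, model path as in the config above):

from vllm import LLM, SamplingParams

# temperature=0.0 selects greedy decoding, so repeated runs over the
# same prompts return the same texts (up to rare numerical effects).
params = SamplingParams(temperature=0.0, max_tokens=32)

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
outputs = llm.generate(["Question: where is the Eiffel Tower?\n"], params)
print(outputs[0].outputs[0].text)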

ignorejjj commented 5 days ago

What code are you using? Could you provide the contents of the yaml file in the experiment folder?

WenliangZhoushan commented 5 days ago

Here is the code:

from flashrag.config import Config
from flashrag.utils import get_dataset
from flashrag.prompt import PromptTemplate
from flashrag.pipeline import SequentialPipeline

config_dict = {
    'data_dir': '/home/v-wenlzheng/FlashRAG/data/FlashRAG_datasets',
    'test_sample_num': 1000,
    'index_path': '/home/v-wenlzheng/FlashRAG/data/FlashRAG_datasets/retrieval-corpus/e5_flat_inner.index',
    'corpus_path': '/home/v-wenlzheng/FlashRAG/data/FlashRAG_datasets/retrieval-corpus/wiki18_100w.jsonl',
    'framework': 'vllm',
    'model2path': {'e5': 'intfloat/e5-base-v2', 'llama3-8B-instruct': 'meta-llama/Meta-Llama-3-8B-Instruct'},
    'generator_model': 'llama3-8B-instruct',
    'retrieval_method': 'e5',
    'metrics': ['em'],
    'retrieval_topk': 5,
    'save_note': 'standard_rag_3_rounds',
    'save_intermediate_data': True,
}

config = Config(config_dict=config_dict)

all_split = get_dataset(config)
test_data = all_split['test']

# Implicit string concatenation avoids embedding the source indentation
# in the prompt, unlike the original backslash-continued string.
prompt_template = PromptTemplate(
    config,
    system_prompt="Answer the question based on the given document. "
                  "Only give me the answer and do not output any other words."
                  "\nThe following are given documents.\n\n{reference}",
    user_prompt="Question: {question}\n",
)

pipeline = SequentialPipeline(config, prompt_template=prompt_template)

# Run the same split three times to check run-to-run consistency.
for _ in range(3):
    output_dataset = pipeline.naive_run(test_data, do_eval=True)
    # output_dataset = pipeline.run(test_data, do_eval=True)

Here is the config.yml from the output folder:

corpus_path: /home/v-wenlzheng/FlashRAG/data/FlashRAG_datasets/retrieval-corpus/wiki18_100w.jsonl
data_dir: /home/v-wenlzheng/FlashRAG/data/FlashRAG_datasets
dataset_name: nq
dataset_path: /home/v-wenlzheng/FlashRAG/data/FlashRAG_datasets/nq
device: !!python/object/apply:torch.device
- cuda
faiss_gpu: false
framework: vllm
generation_params:
  max_tokens: 32
generator_batch_size: 4
generator_max_input_len: 1024
generator_model: llama3-8B-instruct
generator_model_path: meta-llama/Meta-Llama-3-8B-Instruct
gpu_id: 0,1,2,3
gpu_memory_utilization: 0.85
index_path: /home/v-wenlzheng/FlashRAG/data/FlashRAG_datasets/retrieval-corpus/e5_flat_inner.index
method2index:
  bm25: null
  contriever: null
  e5: null
metric_setting:
  retrieval_recall_topk: 5
  tokenizer_name: gpt-4
metrics:
- em
model2path:
  bge: intfloat/e5-base-v2
  contriever: facebook/contriever
  e5: intfloat/e5-base-v2
  llama2-13B: meta-llama/Llama-2-13b-hf
  llama2-13B-chat: meta-llama/Llama-2-13b-chat-hf
  llama2-7B: meta-llama/Llama-2-7b-hf
  llama2-7B-chat: meta-llama/Llama-2-7b-chat-hf
  llama3-8B-instruct: meta-llama/Meta-Llama-3-8B-Instruct
model2pooling:
  bge: cls
  contriever: mean
  dpr: cls
  e5: mean
  jina: mean
openai_setting:
  api_key: null
  base_url: null
random_sample: false
rerank_batch_size: 256
rerank_max_length: 512
rerank_model_name: null
rerank_model_path: null
rerank_pooling_method: null
rerank_topk: 5
rerank_use_fp16: true
retrieval_batch_size: 256
retrieval_cache_path: null
retrieval_method: e5
retrieval_model_path: intfloat/e5-base-v2
retrieval_pooling_method: mean
retrieval_query_max_length: 128
retrieval_topk: 5
retrieval_use_fp16: true
save_dir: output/nq_2024_07_01_06_55_naive_3_rounds
save_intermediate_data: true
save_metric_score: true
save_note: naive_3_rounds
save_retrieval_cache: true
seed: 2024
split:
- test
test_sample_num: 1000
use_fid: false
use_reranker: false
use_retrieval_cache: false
use_sentence_transformer: false
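
Two values in this dump are worth noting: generation_params only sets max_tokens, so the engine's sampling defaults apply, and generator_max_input_len is 1024, which a prompt carrying five retrieved passages can exceed. A quick token-count check (a sketch; the stand-in passage is an assumption, not real corpus text):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# Stand-in prompt: five ~100-word passages plus the instruction and question.
passage = ("lorem " * 100).strip() + "\n"
prompt = ("Answer the question based on the given document. "
          "Only give me the answer and do not output any other words.\n"
          "The following are given documents.\n\n" + passage * 5 +
          "Question: who got the first nobel prize in physics?\n")

n = len(tok(prompt)["input_ids"])
print(f"{n} tokens vs generator_max_input_len=1024 -> truncated: {n > 1024}")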

ignorejjj commented 5 days ago

Try adding the following settings to config_dict (a merged sketch follows the list):

  1. generation_params: {'do_sample':False}
  2. generator_max_input_len: 4096
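
Merged into the config_dict from earlier, the two suggestions would look like this (a sketch showing only the added keys):

config_dict = {
    # ... all of the settings shown above, plus:
    'generation_params': {'do_sample': False},  # greedy decoding, for repeatable runs
    'generator_max_input_len': 4096,            # room for all five retrieved passages
}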

WenliangZhoushan commented 5 days ago

> Try adding the following settings to config_dict:
>
>   1. generation_params: {'do_sample': False}
>   2. generator_max_input_len: 4096

OK, thanks! I'll give it a try.