NLPJCL / RAG-Retrieval

Unified, efficient fine-tuning for RAG retrieval, including Embedding, ColBERT, and Cross Encoder models
MIT License
441 stars, 38 forks

Error when fine-tuning the Reranker model: raise Exception("Could not find the transformer layer class to wrap in the model.") #29

huangjun11 closed this issue 2 months ago

huangjun11 commented 2 months ago

Traceback (most recent call last):
  File "/data/junhuang/rag_model/RAG-Retrieval/rag_retrieval/train/reranker/train_reranker.py", line 175, in <module>
    main()
  File "/data/junhuang/rag_model/RAG-Retrieval/rag_retrieval/train/reranker/train_reranker.py", line 105, in main
    model = accelerator.prepare(model)
  File "/home/root123/.local/lib/python3.10/site-packages/accelerate/accelerator.py", line 1304, in prepare
    result = tuple(
  File "/home/root123/.local/lib/python3.10/site-packages/accelerate/accelerator.py", line 1305, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/home/root123/.local/lib/python3.10/site-packages/accelerate/accelerator.py", line 1181, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/home/root123/.local/lib/python3.10/site-packages/accelerate/accelerator.py", line 1461, in prepare_model
    self.state.fsdp_plugin.set_auto_wrap_policy(model)
  File "/home/root123/.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py", line 1367, in set_auto_wrap_policy
    raise Exception("Could not find the transformer layer class to wrap in the model.")
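The last frames show where the error originates: with FSDP's transformer-based auto-wrap policy, `set_auto_wrap_policy` looks up each class name listed for wrapping (the `fsdp_transformer_layer_cls_to_wrap` setting) among the model's submodules, and raises this exception when a name matches nothing. Below is a minimal sketch, assuming PyTorch, for checking which layer class a model actually contains. `find_module_class` and `candidate_layer_classes` are hypothetical helpers, the lookup only roughly mirrors what accelerate does internally, and the toy model stands in for a real encoder:

```python
import torch.nn as nn


def find_module_class(model: nn.Module, class_name: str):
    """Return the first submodule class whose __name__ equals class_name, else None.

    Roughly mirrors the name lookup FSDP auto-wrap performs; a None result
    corresponds to the "Could not find the transformer layer class" case.
    """
    for module in model.modules():
        if type(module).__name__ == class_name:
            return type(module)
    return None


def candidate_layer_classes(model: nn.Module) -> list:
    """List class names of modules stored inside nn.ModuleList containers.

    Transformer encoders usually keep their repeated blocks in a ModuleList
    (e.g. `encoder.layer`), so these names are candidates for wrapping.
    """
    names = set()
    for module in model.modules():
        if isinstance(module, nn.ModuleList):
            for child in module:
                names.add(type(child).__name__)
    return sorted(names)


# Toy stand-ins for an encoder block and an encoder (hypothetical, for illustration).
class ToyLayer(nn.Module):
    def __init__(self, dim: int = 4):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


class ToyEncoder(nn.Module):
    def __init__(self, num_layers: int = 2):
        super().__init__()
        self.layer = nn.ModuleList([ToyLayer() for _ in range(num_layers)])

    def forward(self, x):
        for block in self.layer:
            x = block(x)
        return x


if __name__ == "__main__":
    toy = ToyEncoder()
    print(candidate_layer_classes(toy))          # ['ToyLayer']
    print(find_module_class(toy, "BertLayer"))   # None: the failing case

    # With the real checkpoint (requires transformers and the local model files):
    # from transformers import AutoModelForSequenceClassification
    # model = AutoModelForSequenceClassification.from_pretrained(
    #     "/data/junhuang/rag_model/FlagEmbedding/model/bge-reranker-base")
    # print(candidate_layer_classes(model))
```

For bge-reranker-base, which to our understanding loads as an XLM-RoBERTa sequence-classification model, we would expect `XLMRobertaLayer`; running the commented lines against the local checkpoint is the way to confirm it.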

Training command:
CUDA_VISIBLE_DEVICES="4" accelerate launch --config_file ../../../config/default_fsdp.yaml train_reranker.py \
  --model_name_or_path "/data/junhuang/rag_model/FlagEmbedding/model/bge-reranker-base" \
  --dataset "/data/junhuang/rag_model/RAG-Retrieval/qa_train.jsonl" \
  --output_dir "./output/t2ranking_100_example" \
  --loss_type "classfication" \
  --batch_size 16 \
  --lr 5e-5 \
  --epochs 10 \
  --num_labels 1 \
  --log_with 'wandb' \
  --save_on_epoch_end 1 \
  --warmup_proportion 0.1 \
  --gradient_accumulation_steps 1 \
  --max_len 512
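The command passes `--config_file ../../../config/default_fsdp.yaml`, so the class name FSDP tries to wrap comes from that file. One possible direction, not necessarily the approach in the referenced issue #5, is to point the FSDP section at the reranker's actual layer class. A minimal sketch, assuming the file uses accelerate's standard `fsdp_config` keys (`fsdp_auto_wrap_policy`, `fsdp_transformer_layer_cls_to_wrap`) and that PyYAML is installed; `set_layer_cls_to_wrap` and the output filename are hypothetical:

```python
def set_layer_cls_to_wrap(cfg: dict, layer_cls: str) -> dict:
    """Hypothetical helper: point transformer-based auto-wrap at `layer_cls`.

    `cfg` is an accelerate config loaded as a dict; only the FSDP section
    is touched, and every other key is left as-is.
    """
    fsdp = cfg.setdefault("fsdp_config", {})
    fsdp["fsdp_auto_wrap_policy"] = "TRANSFORMER_BASED_WRAP"
    fsdp["fsdp_transformer_layer_cls_to_wrap"] = layer_cls
    return cfg


# Usage (assumes PyYAML; writes a copy instead of overwriting the shared config):
# import yaml
# with open("../../../config/default_fsdp.yaml") as f:
#     cfg = yaml.safe_load(f)
# set_layer_cls_to_wrap(cfg, "XLMRobertaLayer")  # assumed class for bge-reranker-base
# with open("../../../config/default_fsdp_reranker.yaml", "w") as f:  # hypothetical name
#     yaml.safe_dump(cfg, f, sort_keys=False)
```

Writing a separate copy leaves the original config untouched for other training runs; the new file would then be passed via `--config_file`.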

NLPJCL commented 2 months ago

See https://github.com/NLPJCL/RAG-Retrieval/issues/5