shibing624 / MedicalGPT

MedicalGPT: Training Your Own Medical GPT Model with ChatGPT Training Pipeline. Trains a medical large language model, implementing continued pretraining (PT), supervised fine-tuning (SFT), RLHF, DPO, and ORPO.

DPO training fails with "IndexError: Invalid key: 0 is out of bounds for size 0" #375

Closed dage0127 closed 1 month ago

dage0127 commented 1 month ago

Below are the execution steps and a description of the problem. Please take a look, thanks.

I. Problem description:

  1. Setup: running a copy of "run_training_dpo_pipeline.ipynb" on Google Colab.
  2. Model swap: --model_type auto \ --model_name_or_path Qwen/Qwen1.5-0.5B-Chat \
  3. Progress: PT and SFT training ran fine, and the merge succeeded. The run fails at the DPO step.
  4. Error: "IndexError: Invalid key: 0 is out of bounds for size 0"
  5. Investigation: the error seems to say the eval data was not found, but the log shows the validation split does have data: validation: Dataset({ features: ['system', 'history', 'question', 'response_chosen', 'response_rejected'], num_rows: 500 })

2024-05-16 01:37:11.846 | DEBUG | __main__:main:387 - Num eval_samples: 0
2024-05-16 01:37:11.846 | DEBUG | __main__:main:388 - First eval example:
Traceback (most recent call last):
  File "/content/MedicalGPT/dpo_training.py", line 523, in <module>
    main()
  File "/content/MedicalGPT/dpo_training.py", line 389, in main
    first_example = eval_dataset[0]
  File "/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py", line 2861, in __getitem__
    return self._getitem(key)
  File "/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py", line 2845, in _getitem
    pa_subtable = query_table(self._data, key, indices=self._indices)
  File "/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py", line 587, in query_table
    _check_valid_index_key(key, size)
  File "/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py", line 527, in _check_valid_index_key
    raise IndexError(f"Invalid key: {key} is out of bounds for size {size}")
IndexError: Invalid key: 0 is out of bounds for size 0
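
For context, this is exactly the error the Hugging Face datasets library raises when a zero-row Dataset is indexed. A minimal standalone repro (no MedicalGPT code involved; the column names just mirror the DPO data):

from datasets import Dataset

# A Dataset with the DPO columns but zero rows, mimicking a fully filtered eval split.
empty = Dataset.from_dict({"question": [], "response_chosen": [], "response_rejected": []})
print(len(empty))  # 0
empty[0]           # IndexError: Invalid key: 0 is out of bounds for size 0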

II. Details:

1. Command executed:

!python dpo_training.py \
    --model_type auto \
    --model_name_or_path merged-sft \
    --train_file_dir ./data/reward \
    --validation_file_dir ./data/reward \
    --per_device_train_batch_size 3 \
    --per_device_eval_batch_size 1 \
    --do_train \
    --do_eval \
    --use_peft True \
    --max_train_samples 1000 \
    --max_eval_samples 10 \
    --max_steps 100 \
    --eval_steps 10 \
    --save_steps 50 \
    --max_source_length 128 \
    --max_target_length 128 \
    --output_dir outputs-dpo-v1 \
    --target_modules all \
    --lora_rank 8 \
    --lora_alpha 16 \
    --lora_dropout 0.05 \
    --torch_dtype float16 \
    --fp16 True \
    --device_map auto \
    --report_to tensorboard \
    --remove_unused_columns False \
    --gradient_checkpointing True \
    --cache_dir ./cache

2. Full error output:

/usr/local/lib/python3.10/dist-packages/transformers/deepspeed.py:23: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
  warnings.warn(
2024-05-16 01:37:09.207085: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-16 01:37:09.207134: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-16 01:37:09.208593: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-16 01:37:10.467576: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2024-05-16 01:37:10.975 | INFO | __main__:main:218 - Parse args: ScriptArguments(model_type='auto', model_name_or_path='merged-sft', tokenizer_name_or_path=None, load_in_8bit=False, load_in_4bit=False, cache_dir='./cache', use_fast_tokenizer=False, torch_dtype='float16', device_map='auto', trust_remote_code=True, dataset_name=None, dataset_config_name=None, train_file_dir='./data/reward', validation_file_dir='./data/reward', template_name='vicuna', per_device_train_batch_size=3, per_device_eval_batch_size=1, max_source_length=128, max_target_length=128, min_target_length=4, max_train_samples=1000, max_eval_samples=10, overwrite_cache=False, validation_split_percentage=1, preprocessing_num_workers=4, use_peft=True, qlora=False, target_modules='all', lora_rank=8, lora_dropout=0.05, lora_alpha=16.0, peft_path=None, do_train=True, do_eval=True, beta=0.1, learning_rate=0.0005, lr_scheduler_type='cosine', warmup_steps=100, weight_decay=0.05, optim='adamw_hf', fp16=True, bf16=False, gradient_checkpointing=True, gradient_accumulation_steps=4, save_steps=50, eval_steps=10, logging_steps=1, output_dir='outputs-dpo-v1', max_steps=100, eval_strategy='steps', remove_unused_columns=False, report_to='tensorboard')
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2024-05-16 01:37:11.267 | INFO | __main__:main:241 - Add bos_token: <|im_end|>, bos_token_id: 151645
2024-05-16 01:37:11.267 | DEBUG | __main__:main:248 - Tokenizer: Qwen2Tokenizer(name_or_path='merged-sft', vocab_size=151643, model_max_length=32768, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|im_end|>', 'eos_token': '<|im_end|>', 'pad_token': '<|endoftext|>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>']}, clean_up_tokenization_spaces=False), added_tokens_decoder={ 151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), }
2024-05-16 01:37:11.268 | INFO | __main__:main:276 - train files: ./data/reward/dpo_zh_500.jsonl
2024-05-16 01:37:11.268 | INFO | __main__:main:281 - eval files: ./data/reward/dpo_zh_500.jsonl
2024-05-16 01:37:11.455 | INFO | __main__:main:302 - Raw datasets: DatasetDict({
    train: Dataset({
        features: ['system', 'history', 'question', 'response_chosen', 'response_rejected'],
        num_rows: 500
    })
    validation: Dataset({
        features: ['system', 'history', 'question', 'response_chosen', 'response_rejected'],
        num_rows: 500
    })
})
2024-05-16 01:37:11.457 | DEBUG | __main__:main:344 - Example train_dataset[0]: {'system': '', 'history': [], 'question': '20个关于新鲜果汁菜单的口号,适用于一家名为"Dishes"的餐厅', 'response_chosen': '这里是一个名为“Dishes”的餐厅的20个口号,突出了其新鲜果汁菜单:\n\n1. “品尝Dishes新鲜果汁,感受不同!”\n2. “新鲜榨取,直达您的餐桌 - Dishes果汁纯享!”\n3. “用一杯清新的Dishes果汁开启您的一天!”\n4. “每一口Dishes新鲜果汁都是大自然的味道!”\n5. “Dishes:新鲜果汁是焦点!”\n6. “满足您的口腹之欲,享用Dishes口水直流的农场果汁!”\n7. “新鲜果汁,新鲜味道,新鲜菜肴 - 这是Dishes的承诺!”\n8. “用Dishes营养果汁获得每日所需的维生素和矿物质!”\n9. “解渴滋养心灵,品尝Dishes美味果汁!”\n10. “Dishes:每一口都是完美的味道!”\n11. “新鲜制作,完美平衡 - Dishes果汁是感官的享受!”\n12. “从农场到餐桌,Dishes果汁充满天然好处!”\n13. “踏入Dishes,品尝我们新鲜果汁的甜蜜!”\n14. “用Dishes 100%新鲜水果果汁呵护您的身体!”\n15. “Dishes:每一杯果汁都是用激情和关怀精心制作!”\n16. “沉醉于Dishes新鲜榨取果汁的健康热情!”\n17. “用Dishes招牌果汁混合物提升您的用餐体验!”\n18. “健康饮品的清新转变 - Dishes果汁必尝!”\n19. “加入Dishes的新鲜果汁革命 - 您的味蕾会感激您!”\n20. “Dishes:果汁永远新鲜,味道永远美味!”', 'response_rejected': '1. "与菜肴一起品尝新鲜!"\n2. "菜肴:新鲜果汁,新的开始!"\n3. "用菜肴的新鲜混合果汁提神!"\n4. "菜肴,新鲜就是最好的"\n5. "在菜肴庆祝新鲜"\n6. "与菜肴的新鲜果汁为健康干杯"\n7. "在菜肴发现新鲜的魔力"\n8. "品尝菜肴的新鲜果汁,感受不同"\n9. "在菜肴解锁新鲜"\n10. "用菜肴的新鲜果汁迎接新的一天"\n11. "在菜肴,每天都有新鲜"\n12. "用菜肴的新鲜果汁获得能量"\n13. "在菜肴为生活喝果汁"\n14. "拥抱健康,享受菜肴的新鲜果汁"\n15. "菜肴:新鲜与美味的交汇处"\n16. "在菜肴体验新鲜的力量"\n17. "菜肴:把健康送到你家门口"\n18. "像微风一样清新,菜肴的果汁"\n19. "生命太短暂,只为菜肴的新鲜果汁"\n20. "菜肴:新鲜始终是你一天的首选"'}
/usr/local/lib/python3.10/dist-packages/multiprocess/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
Running tokenizer on dataset (num_proc=4): 100% 500/500 [00:00<00:00, 2236.74 examples/s]
Filter: 100% 500/500 [00:00<00:00, 41085.99 examples/s]
2024-05-16 01:37:11.826 | DEBUG | __main__:main:357 - Num train_samples: 15
2024-05-16 01:37:11.826 | DEBUG | __main__:main:358 - First train example:
2024-05-16 01:37:11.826 | DEBUG | __main__:main:360 - prompt: A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.USER: 奥巴马的丈夫是谁? ASSISTANT:
2024-05-16 01:37:11.826 | DEBUG | __main__:main:361 - chosen: 奥巴马不是和男性结婚的。他是和前第一夫人米歇尔·奥巴马结婚的。
2024-05-16 01:37:11.827 | DEBUG | __main__:main:362 - rejected: 巴拉克·奥巴马的丈夫是巴拉克·奥巴马二世。
2024-05-16 01:37:11.828 | DEBUG | __main__:main:374 - Example eval_dataset[0]: {'system': '', 'history': [], 'question': '20个关于新鲜果汁菜单的口号,适用于一家名为"Dishes"的餐厅', 'response_chosen': '这里是一个名为“Dishes”的餐厅的20个口号,突出了其新鲜果汁菜单:\n\n1. “品尝Dishes新鲜果汁,感受不同!”\n2. “新鲜榨取,直达您的餐桌 - Dishes果汁纯享!”\n3. “用一杯清新的Dishes果汁开启您的一天!”\n4. “每一口Dishes新鲜果汁都是大自然的味道!”\n5. “Dishes:新鲜果汁是焦点!”\n6. “满足您的口腹之欲,享用Dishes口水直流的农场果汁!”\n7. “新鲜果汁,新鲜味道,新鲜菜肴 - 这是Dishes的承诺!”\n8. “用Dishes营养果汁获得每日所需的维生素和矿物质!”\n9. “解渴滋养心灵,品尝Dishes美味果汁!”\n10. “Dishes:每一口都是完美的味道!”\n11. “新鲜制作,完美平衡 - Dishes果汁是感官的享受!”\n12. “从农场到餐桌,Dishes果汁充满天然好处!”\n13. “踏入Dishes,品尝我们新鲜果汁的甜蜜!”\n14. “用Dishes 100%新鲜水果果汁呵护您的身体!”\n15. “Dishes:每一杯果汁都是用激情和关怀精心制作!”\n16. “沉醉于Dishes新鲜榨取果汁的健康热情!”\n17. “用Dishes招牌果汁混合物提升您的用餐体验!”\n18. “健康饮品的清新转变 - Dishes果汁必尝!”\n19. “加入Dishes的新鲜果汁革命 - 您的味蕾会感激您!”\n20. “Dishes:果汁永远新鲜,味道永远美味!”', 'response_rejected': '1. "与菜肴一起品尝新鲜!"\n2. "菜肴:新鲜果汁,新的开始!"\n3. "用菜肴的新鲜混合果汁提神!"\n4. "菜肴,新鲜就是最好的"\n5. "在菜肴庆祝新鲜"\n6. "与菜肴的新鲜果汁为健康干杯"\n7. "在菜肴发现新鲜的魔力"\n8. "品尝菜肴的新鲜果汁,感受不同"\n9. "在菜肴解锁新鲜"\n10. "用菜肴的新鲜果汁迎接新的一天"\n11. "在菜肴,每天都有新鲜"\n12. "用菜肴的新鲜果汁获得能量"\n13. "在菜肴为生活喝果汁"\n14. "拥抱健康,享受菜肴的新鲜果汁"\n15. "菜肴:新鲜与美味的交汇处"\n16. "在菜肴体验新鲜的力量"\n17. "菜肴:把健康送到你家门口"\n18. "像微风一样清新,菜肴的果汁"\n19. "生命太短暂,只为菜肴的新鲜果汁"\n20. "菜肴:新鲜始终是你一天的首选"'}
2024-05-16 01:37:11.846 | DEBUG | __main__:main:387 - Num eval_samples: 0
2024-05-16 01:37:11.846 | DEBUG | __main__:main:388 - First eval example:
Traceback (most recent call last):
  File "/content/MedicalGPT/dpo_training.py", line 523, in <module>
    main()
  File "/content/MedicalGPT/dpo_training.py", line 389, in main
    first_example = eval_dataset[0]
  File "/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py", line 2861, in __getitem__
    return self._getitem(key)
  File "/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py", line 2845, in _getitem
    pa_subtable = query_table(self._data, key, indices=self._indices)
  File "/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py", line 587, in query_table
    _check_valid_index_key(key, size)
  File "/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py", line 527, in _check_valid_index_key
    raise IndexError(f"Invalid key: {key} is out of bounds for size {size}")
IndexError: Invalid key: 0 is out of bounds for size 0
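
The log pinpoints the crash: after the "Filter" step, Num train_samples drops to 15 and Num eval_samples to 0, so eval_dataset[0] indexes an empty split. A rough pre-flight check can show how many pairs fit the token budget before launching training. The sketch below is illustrative only: it assumes a simple per-field token count, which may not match dpo_training.py's exact filtering criteria; the file path and column names are taken from the log above.

import json
from transformers import AutoTokenizer

# Approximate the length filter: count pairs whose prompt and both responses
# fit the 128-token source/target budget used in the failing run.
tok = AutoTokenizer.from_pretrained("merged-sft", trust_remote_code=True)
max_source_length = max_target_length = 128

kept = 0
with open("./data/reward/dpo_zh_500.jsonl", encoding="utf-8") as f:
    for line in f:
        ex = json.loads(line)
        source_ok = len(tok(ex["question"])["input_ids"]) <= max_source_length
        target_ok = all(
            len(tok(ex[key])["input_ids"]) <= max_target_length
            for key in ("response_chosen", "response_rejected")
        )
        kept += source_ok and target_ok
print(f"{kept} of 500 pairs fit the 128/128 budget")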

shibing624 commented 1 month ago

The validation data doesn't meet the length requirements, so it was all filtered out (the train split also kept only 15 examples). Either loosen the length limits, or switch to an eval set whose examples pass the filter. Concretely (see also the sketch after this list):

  1. --max_eval_samples 1000
  2. --max_source_length 512 --max_target_length 512
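
With the relaxed limits, enough eval examples survive the filter. Independently of the flag change, a small guard in the script (a suggestion, not something dpo_training.py does today) would turn the opaque IndexError into an actionable message:

# Hypothetical guard before the script indexes the eval split:
if len(eval_dataset) == 0:
    raise ValueError(
        "eval_dataset is empty after length filtering; increase "
        "--max_source_length/--max_target_length or use shorter eval examples"
    )
first_example = eval_dataset[0]
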
dage0127 commented 1 month ago

That did the trick. Thanks a lot!