taishan1994 / Llama3.1-Finetuning

Full-parameter, LoRA, and QLoRA fine-tuning for Llama 3.
Apache License 2.0
149 stars · 15 forks

assert len(input_id) == len(target) AssertionError #3

Open CarlChang39 opened 6 months ago

CarlChang39 commented 6 months ago

Hello, I hit this error while trying to reproduce the QLoRA fine-tuning. The traceback is:

```
Traceback (most recent call last):
  File "../finetune_llama3.py", line 452, in <module>
    train()
  File "../finetune_llama3.py", line 445, in train
    trainer.train()
  File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/transformers/trainer.py", line 1624, in train
    return inner_training_loop(
  File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/transformers/trainer.py", line 1928, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/accelerate/data_loader.py", line 452, in __iter__
    current_batch = next(dataloader_iter)
  File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "../finetune_llama3.py", line 255, in __getitem__
    ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.max_len)
  File "../finetune_llama3.py", line 192, in preprocess
    assert len(input_id) == len(target)
AssertionError
```

In practice, input_id is always exactly 1 longer than target.
I'm running on six 1080 Ti GPUs. The shell script is:

```shell
NCCL_P2P_DISABLE=1 \
NCCL_IB_DISABLE=1 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 \
torchrun \
  --nproc_per_node 6 \
  --nnodes 1 \
  --node_rank 0 \
  --master_addr localhost \
  --master_port 6601 \
  ../finetune_llama3.py \
  --model_name_or_path "../model_hub/LLM-Research/Meta-Llama-3-8B-Instruct/" \
  --data_path "../data/Belle_sampled_qwen.json" \
  --fp16 True \
  --output_dir "../output/llama3_8B_qlora" \
  --num_train_epochs 100 \
  --per_device_train_batch_size 1 \
  --per_device_eval_batch_size 1 \
  --gradient_accumulation_steps 16 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 5 \
  --save_total_limit 1 \
  --learning_rate 1e-5 \
  --weight_decay 0.1 \
  --adam_beta2 0.95 \
  --warmup_ratio 0.01 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --report_to "none" \
  --model_max_length 4096 \
  --gradient_checkpointing True \
  --lazy_preprocess True \
  --deepspeed "../config/ds_config_zero2.json" \
  --use_lora \
  --load_in_4bit \
  --q_lora
```

Compared with the provided script, I only changed CUDA_VISIBLE_DEVICES and nproc_per_node, and switched bf16 to fp16.

taishan1994 commented 6 months ago

Can you print them out so we can see exactly where they differ?

taishan1994 commented 6 months ago

[image] Looks fine on my end.

CarlChang39 commented 6 months ago

> Can you print them out so we can see exactly where they differ?

I saved the intermediate results; in each file name, the left number is the length of input_id and the right number is the length of target. It happens on every run.

[image]

In the file named "186 185" (input_id length 186, target length 185):

input_id:[128000, 128006, 128000, 9125, 128007, 128000, 198, 128000, 2675, 527, 264, 55066, 6369, 6465, 889, 2744, 31680, 304, 55066, 6604, 0, 128009, 128006, 128000, 882, 128007, 128000, 198, 128000, 106161, 100815, 107015, 9554, 106246, 3922, 117805, 115532, 32943, 9554, 127944, 117633, 113333, 96455, 124671, 118402, 3922, 120605, 127944, 9554, 103572, 82317, 75863, 102654, 127198, 9554, 104654, 124662, 124778, 102924, 118742, 1811, 92672, 3922, 17792, 49792, 118034, 110745, 20675, 5486, 103458, 112352, 50667, 61826, 42399, 104696, 9554, 109589, 104587, 18184, 102208, 23039, 102208, 42052, 103652, 117661, 117724, 122705, 33748, 110053, 107644, 103624, 122943, 124858, 17297, 3922, 83687, 33208, 47770, 25287, 104724, 112743, 105231, 112157, 28190, 33764, 125648, 1811, 34226, 53901, 30590, 51611, 33764, 83687, 33208, 47770, 25287, 9554, 114099, 102778, 3922, 112026, 121915, 9554, 118556, 34208, 106246, 105000, 86206, 105231, 123882, 103229, 108199, 104696, 109189, 34208, 122705, 106556, 38741, 60843, 37985, 17905, 28190, 44388, 38574, 17161, 22656, 9554, 123092, 127555, 106015, 1811, 128009, 128006, 128000, 78191, 128007, 128000, 198, 128000, 107015, 5486, 127944, 5486, 17792, 49792, 118034, 110745, 20675, 5486, 103458, 112352, 5486, 83687, 33208, 47770, 25287, 5486, 105231, 5486, 104696, 109189, 5486, 122705, 106556, 1811, 128009]

target:[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 128000, 107015, 5486, 127944, 5486, 17792, 49792, 118034, 110745, 20675, 5486, 103458, 112352, 5486, 83687, 33208, 47770, 25287, 5486, 105231, 5486, 104696, 109189, 5486, 122705, 106556, 1811, 128009]
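For reference, the invariant that the failing assert enforces can be sketched like this (a minimal illustration, not the repo's actual code; IGNORE_TOKEN_ID = -100 as seen in the target dump above):

```python
IGNORE_TOKEN_ID = -100  # label value skipped by the cross-entropy loss

def check_alignment(input_id, target):
    # Labels must align one-to-one with input tokens: prompt positions
    # are masked with IGNORE_TOKEN_ID, but the two lengths must match.
    assert len(input_id) == len(target), (len(input_id), len(target))
    # Every unmasked label should equal the corresponding input token.
    for tok, lab in zip(input_id, target):
        if lab != IGNORE_TOKEN_ID:
            assert lab == tok, (tok, lab)

# A well-formed pair passes:
check_alignment([9906, 11, 1917], [IGNORE_TOKEN_ID, 11, 1917])
# The dump above (186 input tokens vs 185 labels) fails the first
# assert, which is exactly the error in the traceback.
```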

CarlChang39 commented 6 months ago

> [image] Looks fine on my end.

[image] I ran it again; the mismatches all occur in the assistant turns.

CarlChang39 commented 6 months ago

I looked at your code. [image] The length difference between _input_id and _target comes out to len(nl_tokens) - 1, right? But I printed nl_tokens = [128000, 198], which has length 2. Doesn't that mean _input_id will always end up 1 longer than _target, which is why input_id ends up 1 longer than target?

taishan1994 commented 6 months ago

That part should be changed to the ignore token repeated len(nl_tokens) times; I didn't account for nl_tokens having length 2.
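Spelled out (a hypothetical reconstruction of the off-by-one, not the repo's exact code; the header and reply token IDs are illustrative values taken from the dump above):

```python
IGNORE_TOKEN_ID = -100

nl_tokens = [128000, 198]                      # tokenized "\n": length 2, not 1
role_tokens = [128006, 128000, 78191, 128007]  # e.g. the assistant header
value_tokens = [107015, 5486, 128009]          # tokenized reply (truncated)

_input_id = role_tokens + nl_tokens + value_tokens

# Buggy: a single ignore token stands in for the newline, so _target
# comes out len(nl_tokens) - 1 = 1 token shorter than _input_id.
_target_buggy = ([IGNORE_TOKEN_ID] * len(role_tokens)
                 + [IGNORE_TOKEN_ID]
                 + value_tokens)

# Fixed: repeat the ignore token once per newline token.
_target_fixed = ([IGNORE_TOKEN_ID] * len(role_tokens)
                 + [IGNORE_TOKEN_ID] * len(nl_tokens)
                 + value_tokens)

assert len(_target_buggy) == len(_input_id) - 1  # reproduces the mismatch
assert len(_target_fixed) == len(_input_id)      # the assert now passes
```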

CarlChang39 commented 6 months ago

> That part should be changed to the ignore token repeated len(nl_tokens) times; I didn't account for nl_tokens having length 2.

You mean the last one in _target, the ignore token right before tokenizer(value), right?

taishan1994 commented 6 months ago

Yes.

CarlChang39 commented 6 months ago

> Yes.

OK, that should do it. Thanks a lot!