hiyouga / LLaMA-Factory

Unify Efficient Fine-Tuning of 100+ LLMs

Qwen2: debugging reveals labels are all -100 #4635

Closed: xjtulien closed this issue 2 days ago

xjtulien commented 2 days ago

Reminder

System Info

```sh
llamafactory-cli train \
    --stage sft \
    --do_train True \
    --model_name_or_path /home/nince9/Pycharm/Qwen/checkpoint/eval_2024-06-14-16-28-44 \
    --preprocessing_num_workers 16 \
    --finetuning_type lora \
    --template qwen \
    --flash_attn auto \
    --dataset_dir data \
    --dataset my_train_data \
    --cutoff_len 1024 \
    --learning_rate 0.0002 \
    --num_train_epochs 12.0 \
    --max_samples 100000 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --max_grad_norm 1.0 \
    --logging_steps 5 \
    --save_steps 100 \
    --warmup_steps 0 \
    --optim adamw_torch \
    --packing False \
    --report_to none \
    --output_dir saves/Qwen2-7B-Chat/lora/train_2024-06-30-15-47-42 \
    --bf16 True \
    --plot_loss True \
    --ddp_timeout 180000000 \
    --include_num_input_tokens_seen True \
    --lora_rank 8 \
    --lora_alpha 16 \
    --lora_dropout 0 \
    --lora_target all
```

Reproduction

transformers 4.41.2. The --model_name_or_path above, /home/nince9/Pycharm/Qwen/checkpoint/eval_2024-06-14-16-28-44, points to a checkpoint exported by merging the LoRA adapters after an earlier LoRA SFT run.
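For readers unfamiliar with that export workflow, here is a hedged sketch of merging LoRA adapters into the base weights with peft (the paths are placeholders, not the user's actual checkpoints, and LLaMA-Factory's own export command wraps an equivalent step):

```python
# Illustrative only: merge a LoRA adapter into its base model and export
# a standalone checkpoint. Paths and model names are assumptions.
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct")
merged = PeftModel.from_pretrained(base, "path/to/lora_adapter").merge_and_unload()
merged.save_pretrained("path/to/merged_export")
AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct").save_pretrained("path/to/merged_export")
```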

Expected behavior

No response

Others

After debugging, I found this code at line 1166 of transformers/models/qwen2/modeling_qwen2.py:

```python
if labels is not None:
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, self.config.vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)
```

All of the labels are -100, yet they are still fed into the cross-entropy loss against the logits. Where did things go wrong, and how do I fix it? Shouldn't the labels correspond to the output token ids? I would really appreciate a reply from the author when you have time.
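For context on why -100 matters here (a minimal sketch, not part of the original thread): PyTorch's CrossEntropyLoss uses ignore_index=-100 by default, so positions labeled -100 are excluded from the loss entirely, and a batch where every position is -100 degenerates to NaN:

```python
# Minimal demonstration that -100 is the default ignore_index of
# CrossEntropyLoss, so -100 positions contribute nothing to the loss.
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 8
logits = torch.randn(5, vocab_size)            # 5 token positions
loss_fct = CrossEntropyLoss()                  # ignore_index=-100 by default

mixed = torch.tensor([-100, -100, 3, 1, -100])  # prompt masked, response kept
print(loss_fct(logits, mixed))                 # finite: averaged over 2 positions

all_masked = torch.full((5,), -100)
print(loss_fct(logits, all_masked))            # nan: no positions left to average
```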

hiyouga commented 2 days ago

They won't all be -100.
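For background (a hedged sketch, not hiyouga's code): in supervised fine-tuning, LLaMA-Factory builds labels by masking the prompt tokens with IGNORE_INDEX (-100) and keeping real token ids only for the response, so a long run of -100 at the start of each example is expected. Roughly:

```python
# Simplified illustration of the usual SFT label construction; the real
# implementation differs in detail (chat templates, truncation, multi-turn).
IGNORE_INDEX = -100

def build_labels(prompt_ids: list[int], response_ids: list[int]) -> tuple[list[int], list[int]]:
    """Return (input_ids, labels) with the prompt masked out of the loss."""
    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
    return input_ids, labels

input_ids, labels = build_labels([101, 17, 42], [7, 9, 2])
print(labels)  # [-100, -100, -100, 7, 9, 2] -- only the response is supervised
```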

xjtulien commented 1 day ago

[screenshot: 屏幕截图 2024-07-01 204740] Stepping to this point in the debugger, I find the labels really are all 0.
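A quick way to check this at such a breakpoint (a hedged sketch; `labels` and `tokenizer` are assumed to be in scope from the debugging session, not taken from the thread):

```python
# Hypothetical helper: count the supervised positions in a labels tensor
# and decode them. If masking is correct, the decoded text should be the
# assistant response; zero kept positions means everything was masked out.
import torch

def inspect_labels(labels: torch.Tensor, tokenizer) -> None:
    flat = labels.reshape(-1)
    kept = flat[flat != -100]  # positions that actually enter the loss
    print(f"supervised positions: {kept.numel()} / {flat.numel()}")
    if kept.numel() > 0:
        print(tokenizer.decode(kept.tolist()))
```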