hiyouga / LLaMA-Factory

Unified Efficient Fine-Tuning of 100+ LLMs (ACL 2024)
https://arxiv.org/abs/2403.13372
Apache License 2.0

Error when fine-tuning a LLaVA multimodal model with KTO + LoRA #5202

Closed: asadfgglie closed this issue 2 months ago

asadfgglie commented 2 months ago

Reminder

System Info

Reproduction

llamafactory-cli train \
    --stage kto \
    --do_train True \
    --model_name_or_path llava-hf/llava-1.5-7b-hf \
    --preprocessing_num_workers 16 \
    --finetuning_type lora \
    --template llama2 \
    --flash_attn auto \
    --visual_inputs True \
    --dataset_dir data \
    --dataset kto_en_demo \
    --cutoff_len 1024 \
    --learning_rate 5e-05 \
    --num_train_epochs 1.0 \
    --max_samples 100000 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --max_grad_norm 1.0 \
    --logging_steps 5 \
    --save_steps 100 \
    --warmup_steps 0 \
    --optim adamw_torch \
    --packing False \
    --report_to none \
    --output_dir saves/LLaVA1.5-7B-Chat/lora/testkto \
    --bf16 True \
    --plot_loss True \
    --ddp_timeout 180000000 \
    --include_num_input_tokens_seen True \
    --quantization_bit 8 \
    --quantization_method bitsandbytes \
    --lora_rank 8 \
    --lora_alpha 16 \
    --lora_dropout 0 \
    --lora_target all \
    --pref_beta 0.1 \
    --pref_ftx 0 \
    --pref_loss sigmoid
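As a sanity check on the batch flags above, the sketch below reproduces the numbers the training log reports further down (total train batch size 16, 18 optimization steps for 300 examples). It assumes a single GPU, as the log states, and assumes the Hugging Face Trainer's convention of floor-dividing the number of batches per epoch by the accumulation steps. Both helper names are hypothetical.

```python
import math


def total_batch_size(per_device_batch: int, grad_accum: int) -> int:
    """Hypothetical helper: effective batch size on a single GPU."""
    return per_device_batch * grad_accum


def optimization_steps(num_examples: int, per_device_batch: int,
                       grad_accum: int, epochs: float = 1.0) -> int:
    """Hypothetical helper: optimizer steps, assuming the number of batches
    per epoch is floor-divided by the gradient accumulation steps."""
    batches_per_epoch = math.ceil(num_examples / per_device_batch)  # 300 / 2 = 150
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)     # 150 // 8 = 18
    return math.ceil(epochs * updates_per_epoch)


print(total_batch_size(2, 8), optimization_steps(300, 2, 8))
```

Because of the floor division, the result is 18 steps rather than the 18.75 that a plain 300 / 16 would suggest.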

Expected behavior

When a LLaVA-style multimodal model is fine-tuned with KTO + LoRA on the example KTO dataset (kto_en_demo), training fails with the following error:

[2024-08-17 02:51:59,327] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
08/17/2024 02:52:01 - WARNING - llamafactory.hparams.parser - We recommend enable `upcast_layernorm` in quantized training.
08/17/2024 02:52:01 - INFO - llamafactory.hparams.parser - Process rank: 0, device: cuda:0, n_gpu: 1, distributed training: False, compute dtype: torch.bfloat16
[INFO|tokenization_utils_base.py:2289] 2024-08-17 02:52:01,700 >> loading file tokenizer.model from cache at /home/asadfgglie/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/fa3dd2809b8de6327002947c3382260de45015d4/tokenizer.model
[INFO|tokenization_utils_base.py:2289] 2024-08-17 02:52:01,700 >> loading file tokenizer.json from cache at /home/asadfgglie/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/fa3dd2809b8de6327002947c3382260de45015d4/tokenizer.json
[INFO|tokenization_utils_base.py:2289] 2024-08-17 02:52:01,700 >> loading file added_tokens.json from cache at /home/asadfgglie/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/fa3dd2809b8de6327002947c3382260de45015d4/added_tokens.json
[INFO|tokenization_utils_base.py:2289] 2024-08-17 02:52:01,700 >> loading file special_tokens_map.json from cache at /home/asadfgglie/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/fa3dd2809b8de6327002947c3382260de45015d4/special_tokens_map.json
[INFO|tokenization_utils_base.py:2289] 2024-08-17 02:52:01,700 >> loading file tokenizer_config.json from cache at /home/asadfgglie/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/fa3dd2809b8de6327002947c3382260de45015d4/tokenizer_config.json
[INFO|tokenization_utils_base.py:2533] 2024-08-17 02:52:01,739 >> Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
[INFO|image_processing_base.py:375] 2024-08-17 02:52:02,387 >> loading configuration file preprocessor_config.json from cache at /home/asadfgglie/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/fa3dd2809b8de6327002947c3382260de45015d4/preprocessor_config.json
[INFO|image_processing_base.py:375] 2024-08-17 02:52:02,585 >> loading configuration file preprocessor_config.json from cache at /home/asadfgglie/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/fa3dd2809b8de6327002947c3382260de45015d4/preprocessor_config.json
[INFO|image_processing_base.py:429] 2024-08-17 02:52:02,586 >> Image processor CLIPImageProcessor {
  "crop_size": {
    "height": 336,
    "width": 336
  },
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "CLIPImageProcessor",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "processor_class": "LlavaProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 336
  }
}

[INFO|tokenization_utils_base.py:2289] 2024-08-17 02:52:02,785 >> loading file tokenizer.model from cache at /home/asadfgglie/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/fa3dd2809b8de6327002947c3382260de45015d4/tokenizer.model
[INFO|tokenization_utils_base.py:2289] 2024-08-17 02:52:02,785 >> loading file tokenizer.json from cache at /home/asadfgglie/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/fa3dd2809b8de6327002947c3382260de45015d4/tokenizer.json
[INFO|tokenization_utils_base.py:2289] 2024-08-17 02:52:02,785 >> loading file added_tokens.json from cache at /home/asadfgglie/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/fa3dd2809b8de6327002947c3382260de45015d4/added_tokens.json
[INFO|tokenization_utils_base.py:2289] 2024-08-17 02:52:02,785 >> loading file special_tokens_map.json from cache at /home/asadfgglie/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/fa3dd2809b8de6327002947c3382260de45015d4/special_tokens_map.json
[INFO|tokenization_utils_base.py:2289] 2024-08-17 02:52:02,785 >> loading file tokenizer_config.json from cache at /home/asadfgglie/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/fa3dd2809b8de6327002947c3382260de45015d4/tokenizer_config.json
[INFO|tokenization_utils_base.py:2533] 2024-08-17 02:52:02,817 >> Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
[INFO|processing_utils.py:722] 2024-08-17 02:52:03,253 >> Processor LlavaProcessor:
- image_processor: CLIPImageProcessor {
  "crop_size": {
    "height": 336,
    "width": 336
  },
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "CLIPImageProcessor",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "processor_class": "LlavaProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 336
  }
}

- tokenizer: LlamaTokenizerFast(name_or_path='llava-hf/llava-1.5-7b-hf', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
        0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        32000: AddedToken("<image>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        32001: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

{
  "chat_template": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}",
  "processor_class": "LlavaProcessor"
}

08/17/2024 02:52:03 - INFO - llamafactory.data.loader - Loading dataset kto_en_demo.json...
training example:
input_ids:
[1, 518, 25580, 29962, 29871, 32000, 450, 14879, 27226, 11444, 338, 2675, 1156, 805, 29891, 2519, 29892, 9978, 310, 6601, 775, 393, 2601, 6053, 373, 278, 23226, 310, 4685, 4160, 304, 5702, 470, 766, 6685, 1009, 7395, 14188, 29889, 13, 29984, 29901, 8449, 338, 278, 1900, 15837, 310, 445, 4274, 29973, 13, 29925, 860, 596, 1234, 515, 29901, 13, 29898, 29909, 467, 2787, 13, 29898, 29933, 467, 12453, 13, 29898, 29907, 467, 15197, 13, 29898, 29928, 467, 9327, 29914, 29911, 5309, 13, 29902, 1348, 278, 1234, 338, 518, 29914, 25580, 29962, 438, 1148, 288, 1148, 21023, 21023, 29991, 334, 29887, 335, 6234, 29930, 2803, 592, 1348, 856, 438, 1148, 288, 1148, 21023, 21023, 29991, 334, 29887, 335, 6234, 29930, 450, 1900, 15837, 310, 445, 4274, 338, 856, 334, 29881, 5848, 1245, 29930, 856, 360, 29991, 9327, 29914, 29911, 5309, 29991, 612, 388, 29991, 334, 5527, 9890, 29930, 450, 14879, 27226, 11444, 338, 9963, 1048, 805, 29891, 2519, 29892, 607, 338, 763, 263, 2217, 6601, 6494, 393, 508, 6505, 825, 366, 437, 373, 596, 6601, 1728, 366, 13797, 29889, 739, 29915, 29879, 763, 263, 7035, 10823, 29892, 541, 451, 263, 7575, 697, 29991, 334, 29887, 4692, 29930, 450, 383, 9472, 10753, 304, 5040, 278, 805, 29891, 2519, 515, 2599, 967, 2655, 29892, 577, 896, 29915, 276, 2675, 1156, 372, 29991, 334, 1173, 261, 29930, 14962, 1148, 3634, 29991, 2]
inputs:
<s> [INST] <image> The Federal Trade Commission is going after spyware, bits of computer code that install themselves on the computers of Internet users to track or disrupt their online activities.
Q: Which is the best summary of this article?
Pick your answer from:
(A). World
(B). Sports
(C). Business
(D). Science/Tech
I think the answer is [/INST] Ooh ooh ah ah! *giggle* Let me think... Ooh ooh ah ah! *giggle* The best summary of this article is... *drumroll*... D! Science/Tech! Yay! *confetti* The Federal Trade Commission is talking about spyware, which is like a little computer bug that can watch what you do on your computer without you knowing. It's like a secret agent, but not a nice one! *gasp* The FTC wants to stop the spyware from doing its thing, so they're going after it! *cheer* Woohoo!</s>
label_ids:
[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 438, 1148, 288, 1148, 21023, 21023, 29991, 334, 29887, 335, 6234, 29930, 2803, 592, 1348, 856, 438, 1148, 288, 1148, 21023, 21023, 29991, 334, 29887, 335, 6234, 29930, 450, 1900, 15837, 310, 445, 4274, 338, 856, 334, 29881, 5848, 1245, 29930, 856, 360, 29991, 9327, 29914, 29911, 5309, 29991, 612, 388, 29991, 334, 5527, 9890, 29930, 450, 14879, 27226, 11444, 338, 9963, 1048, 805, 29891, 2519, 29892, 607, 338, 763, 263, 2217, 6601, 6494, 393, 508, 6505, 825, 366, 437, 373, 596, 6601, 1728, 366, 13797, 29889, 739, 29915, 29879, 763, 263, 7035, 10823, 29892, 541, 451, 263, 7575, 697, 29991, 334, 29887, 4692, 29930, 450, 383, 9472, 10753, 304, 5040, 278, 805, 29891, 2519, 515, 2599, 967, 2655, 29892, 577, 896, 29915, 276, 2675, 1156, 372, 29991, 334, 1173, 261, 29930, 14962, 1148, 3634, 29991, 2]
labels:
Ooh ooh ah ah! *giggle* Let me think... Ooh ooh ah ah! *giggle* The best summary of this article is... *drumroll*... D! Science/Tech! Yay! *confetti* The Federal Trade Commission is talking about spyware, which is like a little computer bug that can watch what you do on your computer without you knowing. It's like a secret agent, but not a nice one! *gasp* The FTC wants to stop the spyware from doing its thing, so they're going after it! *cheer* Woohoo!</s>
[INFO|configuration_utils.py:733] 2024-08-17 02:52:04,334 >> loading configuration file config.json from cache at /home/asadfgglie/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/fa3dd2809b8de6327002947c3382260de45015d4/config.json
[INFO|configuration_utils.py:800] 2024-08-17 02:52:04,337 >> Model config LlavaConfig {
  "_name_or_path": "llava-hf/llava-1.5-7b-hf",
  "architectures": [
    "LlavaForConditionalGeneration"
  ],
  "ignore_index": -100,
  "image_token_index": 32000,
  "model_type": "llava",
  "pad_token_id": 32001,
  "projector_hidden_act": "gelu",
  "text_config": {
    "_name_or_path": "lmsys/vicuna-7b-v1.5",
    "architectures": [
      "LlamaForCausalLM"
    ],
    "max_position_embeddings": 4096,
    "model_type": "llama",
    "rms_norm_eps": 1e-05,
    "torch_dtype": "float16",
    "vocab_size": 32064
  },
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.43.2",
  "vision_config": {
    "hidden_size": 1024,
    "image_size": 336,
    "intermediate_size": 4096,
    "model_type": "clip_vision_model",
    "num_attention_heads": 16,
    "num_hidden_layers": 24,
    "patch_size": 14,
    "projection_dim": 768,
    "vocab_size": 32000
  },
  "vision_feature_layer": -2,
  "vision_feature_select_strategy": "default",
  "vocab_size": 32064
}

08/17/2024 02:52:04 - INFO - llamafactory.model.model_utils.quantization - Quantizing model to 8 bit with bitsandbytes.
[INFO|modeling_utils.py:3634] 2024-08-17 02:52:04,341 >> loading weights file model.safetensors from cache at /home/asadfgglie/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/fa3dd2809b8de6327002947c3382260de45015d4/model.safetensors.index.json
[INFO|modeling_utils.py:1572] 2024-08-17 02:52:04,342 >> Instantiating LlavaForConditionalGeneration model under default dtype torch.bfloat16.
[INFO|configuration_utils.py:1038] 2024-08-17 02:52:04,342 >> Generate config GenerationConfig {
  "pad_token_id": 32001
}

[INFO|configuration_utils.py:1038] 2024-08-17 02:52:04,518 >> Generate config GenerationConfig {
  "bos_token_id": 1,
  "eos_token_id": 2
}

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:09<00:00,  3.07s/it]
[INFO|modeling_utils.py:4463] 2024-08-17 02:52:13,887 >> All model checkpoint weights were used when initializing LlavaForConditionalGeneration.

[INFO|modeling_utils.py:4471] 2024-08-17 02:52:13,888 >> All the weights of LlavaForConditionalGeneration were initialized from the model checkpoint at llava-hf/llava-1.5-7b-hf.
If your task is similar to the task the model of the checkpoint was trained on, you can already use LlavaForConditionalGeneration for predictions without further training.
[INFO|configuration_utils.py:993] 2024-08-17 02:52:14,092 >> loading configuration file generation_config.json from cache at /home/asadfgglie/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/fa3dd2809b8de6327002947c3382260de45015d4/generation_config.json
[INFO|configuration_utils.py:1038] 2024-08-17 02:52:14,092 >> Generate config GenerationConfig {
  "bos_token_id": 1,
  "eos_token_id": 2,
  "pad_token_id": 32001
}

08/17/2024 02:52:14 - INFO - llamafactory.model.model_utils.visual - Casting multimodal projector outputs in torch.bfloat16.
08/17/2024 02:52:14 - INFO - llamafactory.model.model_utils.checkpointing - Gradient checkpointing enabled.
08/17/2024 02:52:14 - INFO - llamafactory.model.model_utils.attention - Using torch SDPA for faster training and inference.
08/17/2024 02:52:14 - INFO - llamafactory.model.adapter - Upcasting trainable params to float32.
08/17/2024 02:52:14 - INFO - llamafactory.model.adapter - Fine-tuning method: LoRA
08/17/2024 02:52:14 - INFO - llamafactory.model.model_utils.misc - Found linear modules: o_proj,k_proj,q_proj,down_proj,gate_proj,v_proj,up_proj
08/17/2024 02:52:14 - INFO - llamafactory.model.loader - trainable params: 19,988,480 || all params: 7,083,415,552 || trainable%: 0.2822
[INFO|trainer.py:648] 2024-08-17 02:52:14,487 >> Using auto half precision backend
[INFO|trainer.py:2134] 2024-08-17 02:52:14,660 >> ***** Running training *****
[INFO|trainer.py:2135] 2024-08-17 02:52:14,661 >>   Num examples = 300
[INFO|trainer.py:2136] 2024-08-17 02:52:14,661 >>   Num Epochs = 1
[INFO|trainer.py:2137] 2024-08-17 02:52:14,661 >>   Instantaneous batch size per device = 2
[INFO|trainer.py:2140] 2024-08-17 02:52:14,661 >>   Total train batch size (w. parallel, distributed & accumulation) = 16
[INFO|trainer.py:2141] 2024-08-17 02:52:14,661 >>   Gradient Accumulation steps = 8
[INFO|trainer.py:2142] 2024-08-17 02:52:14,661 >>   Total optimization steps = 18
[INFO|trainer.py:2143] 2024-08-17 02:52:14,663 >>   Number of trainable parameters = 19,988,480
  0%|                                                                                                                                    | 0/18 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/home/asadfgglie/LLaMA-Factory/venv/bin/llamafactory-cli", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/asadfgglie/LLaMA-Factory/src/llamafactory/cli.py", line 111, in main
    run_exp()
  File "/home/asadfgglie/LLaMA-Factory/src/llamafactory/train/tuner.py", line 58, in run_exp
    run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
  File "/home/asadfgglie/LLaMA-Factory/src/llamafactory/train/kto/workflow.py", line 76, in run_kto
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/asadfgglie/LLaMA-Factory/venv/lib/python3.11/site-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/asadfgglie/LLaMA-Factory/venv/lib/python3.11/site-packages/transformers/trainer.py", line 2279, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/asadfgglie/LLaMA-Factory/venv/lib/python3.11/site-packages/transformers/trainer.py", line 3318, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/asadfgglie/LLaMA-Factory/venv/lib/python3.11/site-packages/trl/trainer/kto_trainer.py", line 1357, in compute_loss
    loss, metrics = self.get_batch_loss_metrics(model, inputs)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/asadfgglie/LLaMA-Factory/src/llamafactory/train/kto/trainer.py", line 188, in get_batch_loss_metrics
    self.concatenated_forward(model, batch)
  File "/home/asadfgglie/LLaMA-Factory/src/llamafactory/train/kto/trainer.py", line 146, in concatenated_forward
    target_logps, target_logps_avg = self.forward(model, batch)
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/asadfgglie/LLaMA-Factory/src/llamafactory/train/kto/trainer.py", line 140, in forward
    logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/asadfgglie/LLaMA-Factory/src/llamafactory/train/trainer_utils.py", line 420, in get_batch_logps
    raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.")
ValueError: Logits (batchsize x seqlen) and labels must have the same shape.
  0%|                                                                                                                                    | 0/18 [00:02<?, ?it/s]
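One plausible reading of this error (an assumption, not confirmed in the thread): the training example above contains a single `<image>` token (id 32000), but `LlavaForConditionalGeneration` in the transformers version shown in the config log (4.43.2) replaces that one token with the image's patch features inside the forward pass. The logits then come back longer than the `labels` that LLaMA-Factory passes to `get_batch_logps`. The sketch below works through the arithmetic with `image_size` 336 and `patch_size` 14 from the logged `vision_config`; the `"default"` feature-selection strategy is assumed to drop the CLS token. The helper names and the simplified shape check are hypothetical stand-ins, not LLaMA-Factory's code.

```python
IMAGE_SIZE = 336  # vision_config.image_size in the log above
PATCH_SIZE = 14   # vision_config.patch_size in the log above


def num_image_features(image_size: int = IMAGE_SIZE, patch_size: int = PATCH_SIZE) -> int:
    """Patch features per image: (336 // 14) ** 2 = 576, CLS token excluded."""
    return (image_size // patch_size) ** 2


def merged_seq_len(text_len: int, num_image_tokens: int) -> int:
    """Sequence length after each single <image> token is swapped for its
    patch features (one position removed, 576 inserted)."""
    return text_len + num_image_tokens * (num_image_features() - 1)


def check_shapes(logits_shape: tuple, labels_shape: tuple) -> None:
    """Simplified stand-in for the check that raised: logits are
    (batch, seqlen, vocab), labels are (batch, seqlen)."""
    if logits_shape[:-1] != labels_shape:
        raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.")


# A hypothetical 200-token prompt+response containing one <image> token:
# the labels stay at 200 positions, but the logits grow to 775.
print(merged_seq_len(200, 1))
```

With a batch of 2 and the logged vocabulary size of 32064, the logits would be `(2, 775, 32064)` while the labels stay `(2, 200)`, which is exactly the mismatch the check rejects.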

Others

No response

hiyouga commented 2 months ago

Fixed. See the example configuration: https://github.com/hiyouga/LLaMA-Factory/blob/main/examples/train_lora/qwen2vl_lora_dpo.yaml
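The thread does not describe the fix itself. One general way to make the lengths agree, sketched here as an assumption, is to expand each `<image>` token to one placeholder per patch feature *before* the forward pass and to mask those positions in the labels. Recent transformers versions can perform a similar expansion in the processor. This only helps if the model then treats the repeated tokens as placeholders rather than expanding them a second time. The helper name is hypothetical; 32000 and -100 are the `image_token_index` and `ignore_index` from the config log above, and 576 is the per-image feature count for a 336-pixel image with 14-pixel patches.

```python
IGNORE_INDEX = -100     # ignore_index in the LlavaConfig log above
IMAGE_TOKEN_ID = 32000  # image_token_index in the LlavaConfig log above


def expand_image_tokens(input_ids, labels, num_features=576,
                        image_token_id=IMAGE_TOKEN_ID, ignore_index=IGNORE_INDEX):
    """Hypothetical helper: repeat each image token num_features times and
    mask the matching label positions, so text positions line up one-to-one
    with the patch features the model inserts."""
    if len(input_ids) != len(labels):
        raise ValueError("input_ids and labels must have the same length")
    new_ids, new_labels = [], []
    for tok, lab in zip(input_ids, labels):
        if tok == image_token_id:
            new_ids.extend([image_token_id] * num_features)
            new_labels.extend([ignore_index] * num_features)
        else:
            new_ids.append(tok)
            new_labels.append(lab)
    return new_ids, new_labels


ids = [1, 518, 32000, 450, 2]
labels = [-100, -100, -100, 450, 2]
new_ids, new_labels = expand_image_tokens(ids, labels)
print(len(new_ids), len(new_labels))
```

The prompt stays masked and the response labels (`450, 2`) are untouched. Only the image span grows, so the labels have the same length as the logits the model produces.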