TideDra / VL-RLHF

A RLHF Infrastructure for Vision-Language Models
Apache License 2.0

Error when fine-tuning LLaVA #6

Open njucckevin opened 2 months ago

njucckevin commented 2 months ago
[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/nfs04/chengkz/VL-RLHF/src/vlrlhf/dpo.py", line 146, in <module>
[rank1]:     dpo_trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
[rank1]:   File "/home/data_91_d/anaconda3/envs/chengkz_lvlm/lib/python3.10/site-packages/transformers/trainer.py", line 1885, in train
[rank1]:     return inner_training_loop(
[rank1]:   File "/home/data_91_d/anaconda3/envs/chengkz_lvlm/lib/python3.10/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
[rank1]:     tr_loss_step = self.training_step(model, inputs)
[rank1]:   File "/home/nfs04/chengkz/VL-RLHF/src/vlrlhf/base/trainer.py", line 305, in training_step
[rank1]:     loss_step = super().training_step(model, inputs)
[rank1]:   File "/home/data_91_d/anaconda3/envs/chengkz_lvlm/lib/python3.10/site-packages/transformers/trainer.py", line 3238, in training_step
[rank1]:     loss = self.compute_loss(model, inputs)
[rank1]:   File "/home/data_91_d/anaconda3/envs/chengkz_lvlm/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1081, in compute_loss
[rank1]:     loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
[rank1]:   File "/home/data_91_d/anaconda3/envs/chengkz_lvlm/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1022, in get_batch_loss_metrics
[rank1]:     ) = self.concatenated_forward(model, batch)
[rank1]:   File "/home/nfs04/chengkz/VL-RLHF/src/vlrlhf/models/Llava/__init__.py", line 502, in concatenated_forward
[rank1]:     pixel_values=concatenated_batch["pixel_values"],
[rank1]: KeyError: 'pixel_values'

How can I resolve this error?
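
For readers who hit the same traceback: its last frame shows concatenated_forward reading concatenated_batch["pixel_values"], so the batch that reached the model had no such key. The sketch below is a hypothetical pre-check that reports which keys are actually present. It is not the fix applied in the repository; the helper name require_keys and the sample batch are assumptions made for illustration.

```python
def require_keys(batch, required=("pixel_values",)):
    """Raise a descriptive error if the batch lacks any required key."""
    missing = [key for key in required if key not in batch]
    if missing:
        raise KeyError(
            f"batch is missing {missing}; available keys: {sorted(batch)}"
        )
    return batch


# Hypothetical text-only batch, used only to trigger the check.
text_only_batch = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}

try:
    require_keys(text_only_batch)
except KeyError as err:
    # The message lists the missing key and the keys that are present.
    message = str(err)
```

Calling such a check right before the forward pass would show whether the image tensors were dropped earlier, for example in the dataset or the collator.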

TideDra commented 2 months ago

Fixed. Thank you for the feedback.

njucckevin commented 2 months ago

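Hello, besides evaluating on benchmarks, is there any inference code I could refer to? That is, simple single-sample inference with a pretrained ckpt or a fine-tuned ckpt. Or which repository's or model's code should I refer to? Great work, thanks!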

TideDra commented 2 months ago

You can use the relevant interfaces provided in src/vlrlhf/eval/utils.py:

from vlrlhf.eval.utils import load_model_and_processor

# YourModelPath is a placeholder for the path to a pretrained or fine-tuned checkpoint.
model, processor, generation_kwargs = load_model_and_processor(YourModelPath, None)
image_path = 'a.jpg'
prompt = 'Describe this image'
# Format the prompt together with the image path.
prompt = processor.format_multimodal_prompt(prompt, image_path)
inputs = processor(texts=[prompt], images_path=[image_path], check_format=False)
# Remove the 'label' entry, if present, before generation.
inputs.pop('label', None)
outputs = model.generate(**inputs, use_cache=True, **generation_kwargs)
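
For repeated single-sample runs, the calls in the snippet above could be wrapped in a small helper. This is only a sketch: the function name generate_single is hypothetical, it relies solely on the calls shown above, and it assumes the processor returns a dict-like object, as the pop call in the snippet suggests.

```python
def generate_single(model, processor, generation_kwargs, prompt, image_path):
    """Run one prompt/image pair through the same calls as the snippet above."""
    formatted = processor.format_multimodal_prompt(prompt, image_path)
    inputs = processor(texts=[formatted], images_path=[image_path], check_format=False)
    # Drop the 'label' entry if present, as the original snippet does.
    inputs.pop("label", None)
    return model.generate(**inputs, use_cache=True, **generation_kwargs)


# Intended usage (requires VL-RLHF and a real checkpoint, so left commented out):
# model, processor, generation_kwargs = load_model_and_processor(YourModelPath, None)
# outputs = generate_single(model, processor, generation_kwargs,
#                           "Describe this image", "a.jpg")
```

The helper returns whatever model.generate returns; how to decode that into text depends on the processor and is not shown in this thread.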

njucckevin commented 1 month ago

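Does the repository currently support KTO? I see KTO-related scripts under scripts, for example kto_qwenvl. If it is not supported yet, are there plans to add it? Thanks!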