mbzuai-oryx / Video-ChatGPT

[ACL 2024 🔥] Video-ChatGPT is a video conversation model capable of generating meaningful conversation about videos. It combines the capabilities of LLMs with a pretrained visual encoder adapted for spatiotemporal video representation. We also introduce a rigorous 'Quantitative Evaluation Benchmarking' for video-based conversational models.
https://mbzuai-oryx.github.io/Video-ChatGPT

TypeError: forward() got an unexpected keyword argument 'position_ids' #19

Closed · GZHU-DVL closed this issue 1 year ago

GZHU-DVL commented 1 year ago

[screenshot of the error traceback] Following the tutorial, I can run this project, but execution reports the above error when it reaches this point.

mmaaz60 commented 1 year ago

Hi @GZHU-DVL,

Thank you for your interest in our work. Please make sure that you have followed the mentioned environment setup process and are using the correct versions of the libraries.

If the issue persists, please share the script and the command you are running so we can look into it.

I hope this helps. Thanks.
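For reference, a quick way to confirm which versions are actually installed in the active environment (a generic check, not specific to this repo):

# Print the installed versions of the libraries most relevant to this error.
import torch
import transformers
import tokenizers
import accelerate

print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("tokenizers:", tokenizers.__version__)
print("accelerate:", accelerate.__version__)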

GZHU-DVL commented 1 year ago

The versions of the libraries are as follows:

torch~=2.0.0
tqdm~=4.65.0
transformers
numpy~=1.23
Pillow~=9.5.0
decord~=0.6.0
gradio~=3.23.0
markdown2~=2.4.8
einops~=0.6.1
requests~=2.30.0
sentencepiece~=0.1.99
protobuf~=4.23.2
accelerate~=0.20.3
accelerate==0.19.0
tokenizers>=0.13.3

The command is as follows:

torchrun video_chatgpt/train/train_mem.py \
    --model_name_or_path /gemini/data-2/7b/ \
    --version v1 \
    --data_path /gemini/code/Video-ChatGPT/scripts/video_chatgpt_training.json \
    --video_folder /gemini/data-2/ActivityNet_Train_Video-ChatGPT_Clip-L14_Features/activity_clip-14L_spatio_temporal_356/ \
    --tune_mm_mlp_adapter True \
    --mm_use_vid_start_end \
    --bf16 True \
    --output_dir ./Video-ChatGPT_7B-1.1_Checkpoints \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 3000 \
    --save_total_limit 3 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 100 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess True

GZHU-DVL commented 1 year ago

The problem was solved after I changed the version of transformers.

whcpumpkin commented 1 year ago

> The problem was solved after I changed the version of transformers.

Hi! I faced the same problem. Could you tell me which version of Transformers you use? Thanks!

zhanwenchen commented 11 months ago

The root cause can be seen in this issue: https://github.com/huggingface/transformers/issues/24130

zhanwenchen commented 11 months ago

> The root cause can be seen in this issue: huggingface/transformers#24130

Actually, I was wrong. The problem is that the flash_attn monkey patch was not updated to reflect breaking changes in transformers. To fix this, update the llama_flash_attn_monkey_patch.py in this repository to match this one: https://github.com/lm-sys/FastChat/blob/dd84d166d7694f0cc0c766e5a811d995f5801c77/fastchat/train/llama_flash_attn_monkey_patch.py

The specific commit with this fix is this one: https://github.com/lm-sys/FastChat/commit/daa9c11080ceced2bd52c3e0027e4f64b1512683
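For context on why the mismatch surfaces as this TypeError: the monkey patch replaces LlamaAttention.forward wholesale, so once a newer transformers release starts passing extra keyword arguments to the attention layer, the replaced function must accept them. A rough sketch of the mechanism (simplified; names follow the FastChat file, and the old-style signature below is only illustrative):

# Simplified sketch of how llama_flash_attn_monkey_patch.py hooks into transformers.
import transformers

def forward(self, hidden_states, attention_mask=None, output_attentions=False, use_cache=False):
    # Old-style replacement forward: note there is no position_ids or padding_mask kwarg.
    ...

# The patch overwrites the attention layer's forward method:
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward

# Newer transformers releases pass extra kwargs to the attention layer,
# e.g. position_ids (and, later, padding_mask), so the call hits the
# replaced function above and raises:
#   TypeError: forward() got an unexpected keyword argument 'position_ids'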

After that, you also need to add a kwarg, padding_mask: Optional[torch.LongTensor] = None, to the forward signature as shown below (in case the FastChat repo has not added it yet by the time you read this):

# ...video_chatgpt/train/llama_flash_attn_monkey_patch.py
...
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.LongTensor] = None,  # accepted so calls from newer transformers versions don't fail
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if output_attentions:
...

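Once the patch is updated, a quick sanity check that the new signature took effect (the module path is assumed from the header comment above, and forward is assumed to be a module-level function as in the FastChat file):

# Sanity check: the patched forward should now accept the kwargs that newer
# transformers versions pass to the attention layer.
import inspect
from video_chatgpt.train.llama_flash_attn_monkey_patch import forward

params = inspect.signature(forward).parameters
assert "position_ids" in params, "patch is still missing position_ids"
assert "padding_mask" in params, "patch is still missing padding_mask"
print("monkey patch signature looks up to date")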