huangb23 / VTimeLLM

[CVPR'2024 Highlight] Official PyTorch implementation of the paper "VTimeLLM: Empower LLM to Grasp Video Moments".
https://arxiv.org/pdf/2311.18445.pdf

RuntimeError: cu_seqlens_q must have shape (batch_size + 1) #11

Closed KlayMa527 closed 8 months ago

KlayMa527 commented 8 months ago

Traceback (most recent call last):
  File "vtimellm/train/train_mem.py", line 20, in <module>
    train()
  File "/root/autodl-tmp/VTimeLLM/vtimellm/train/train.py", line 353, in train
    trainer.train()
  File "/root/miniconda3/lib/python3.8/site-packages/transformers/trainer.py", line 1537, in train
    return inner_training_loop(
  File "/root/miniconda3/lib/python3.8/site-packages/transformers/trainer.py", line 1854, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/root/miniconda3/lib/python3.8/site-packages/transformers/trainer.py", line 2735, in training_step
    loss = self.compute_loss(model, inputs)
  File "/root/miniconda3/lib/python3.8/site-packages/transformers/trainer.py", line 2758, in compute_loss
    outputs = model(**inputs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1833, in forward
    loss = self.module(*inputs, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/autodl-tmp/VTimeLLM/vtimellm/train/../../vtimellm/model/vtimellm.py", line 171, in forward
    outputs = self.model(
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1058, in forward
    layer_outputs = self._gradient_checkpointing_func(
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 107, in forward
    outputs = run_function(*args)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 796, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/autodl-tmp/VTimeLLM/vtimellm/train/llama_flash_attn_monkey_patch.py", line 92, in forward
    output_unpad = flash_attn_unpadded_qkvpacked_func(
  File "/root/miniconda3/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 887, in flash_attn_varlen_qkvpacked_func
    return FlashAttnVarlenQKVPackedFunc.apply(
  File "/root/miniconda3/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/root/miniconda3/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 288, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
  File "/root/miniconda3/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 85, in _flash_attn_varlen_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: cu_seqlens_q must have shape (batch_size + 1)

I followed the instructions to set up the data and features, but encountered the error above during stage 1 training. I would appreciate it if you could help me with this question.
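For context, the shape constraint in the error message comes from flash-attn's variable-length (varlen) kernels: cu_seqlens_q is expected to be an int32 tensor of cumulative sequence lengths with exactly batch_size + 1 entries, starting at 0. Below is a minimal sketch (my own illustration, not code from this repo) of what such a tensor looks like:

```python
# Sketch of the invariant behind "cu_seqlens_q must have shape (batch_size + 1)":
# an int32 tensor of cumulative per-sample sequence lengths, prefixed with 0.
import torch
import torch.nn.functional as F

seqlens = torch.tensor([7, 5, 9], dtype=torch.int32)  # per-sample token counts, batch_size = 3
cu_seqlens_q = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

print(cu_seqlens_q)        # tensor([ 0,  7, 12, 21], dtype=torch.int32)
print(cu_seqlens_q.shape)  # torch.Size([4]) == batch_size + 1
```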

KlayMa527 commented 8 months ago

I solved this problem by following this link: https://github.com/Dao-AILab/flash-attention/issues/742
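For readers hitting the same error, a quick first diagnostic (a sketch of my own, not a step taken from the linked thread) is to check the installed flash-attn version and which qkvpacked entry points it exposes. The traceback above shows that, in this environment, the monkey patch's flash_attn_unpadded_qkvpacked_func call resolves to flash_attn_varlen_qkvpacked_func, so a flash-attn build whose varlen API disagrees with what llama_flash_attn_monkey_patch.py passes can surface exactly this RuntimeError.

```python
# Hedged diagnostic sketch: report the installed flash-attn version and
# whether the old (unpadded) and new (varlen) qkvpacked functions exist.
import flash_attn
import flash_attn.flash_attn_interface as fai

print("flash-attn version:", flash_attn.__version__)
for name in ("flash_attn_unpadded_qkvpacked_func", "flash_attn_varlen_qkvpacked_func"):
    print(f"{name} available: {hasattr(fai, name)}")
```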