mutonix / Vript

Running the demo on an A100 80GB: CUDA out of memory #6

Closed: LAW1223 closed this issue 1 month ago

mutonix commented 1 month ago

Can you share more details about how you hit the OOM? You can cap the model's output length so it does not generate overly long text, which may cause the OOM.
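A rough sketch of why output length matters for memory: the key/value cache kept during generation grows linearly with the total sequence length and with the number of beams. The figures assume a 7B LLaMA-family language model (32 layers, hidden size 4096) in fp16. The `instruct_blip_vicuna7b` checkpoint name suggests this, but it is an assumption here. `kv_cache_bytes` is a hypothetical helper, and the estimate ignores model weights, the vision encoder and activations, so it is only one part of total GPU memory.

```python
def kv_cache_bytes(seq_len: int, num_beams: int = 1, n_layers: int = 32,
                   hidden_size: int = 4096, bytes_per_elem: int = 2) -> int:
    """Rough KV-cache size for a LLaMA-7B-style decoder (assumed fp16 by default)."""
    # Every layer caches one key and one value vector of size hidden_size per token,
    # and beam search keeps a separate cache for each beam.
    return 2 * n_layers * seq_len * hidden_size * bytes_per_elem * num_beams


for seq_len in (1024, 4096, 8192):
    gib = kv_cache_bytes(seq_len, num_beams=3) / 2**30
    # prints ~1.5, ~6.0 and ~12.0 GiB for the three lengths
    print(f"max_length={seq_len:5d}, num_beams=3: ~{gib:.1f} GiB of KV cache")
```

Under these assumptions, halving the total length or dropping from 3 beams to 1 shrinks this part of the memory footprint proportionally.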

LAW1223 commented 1 month ago

Thanks a lot. The generation config is:

    chat.upload_video(video, chat_state, img_list, 64, text=prompt)
    chat.ask("###Human: " + prompt + " ###Assistant: ", chat_state)
    llm_message = chat.answer(conv=chat_state, img_list=img_list, num_beams=3,
                              do_sample=True, max_new_tokens=2048, top_p=0.9,
                              repetition_penalty=1.5, max_length=8192)[0]

and the command is:

    python demo.py \
        --video-path /data/captioning/Vript/emoji_1.mp4 \
        --cfg-path config/vriptor_stllm_stage2.yaml \
        --gpu-id 0 \
        --ckpt-path /mnt/models/Vriptor-STLLM

In config/vriptor_stllm_stage2.yaml, I changed q_former_model to my local path: '/mnt/captioning/Vript/instruct_blip_vicuna7b_trimmed.pth'

mutonix commented 1 month ago

You can reduce the max_length from 8192 to 1024.

LAW1223 commented 1 month ago

When I changed it, I got this error:

        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
      File "/opt/conda/envs/vriptor/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1234, in prepare_inputs_for_generation
        past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
    IndexError: index 0 is out of bounds for dimension 0 with size 0
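One possible explanation for the size-0 `IndexError` assumes that `chat.answer` follows the MiniGPT-4-style pattern of trimming the context embeddings from the front so that context plus `max_new_tokens` fits in `max_length`. This pattern is an assumption and has not been confirmed from the Vript code. Under it, `max_new_tokens=2048` with `max_length=1024` puts the trim start past the end of the context, and the model receives an empty input. `kept_context_length` and the 600-token context length are hypothetical.

```python
# Hypothetical model of a MiniGPT-4-style chat.answer(); not taken from the Vript source.
def kept_context_length(context_len: int, max_new_tokens: int, max_length: int) -> int:
    """Return how many context embeddings survive the front truncation."""
    context = list(range(context_len))  # stand-in for the context embedding sequence
    current_max_len = context_len + max_new_tokens
    # Drop tokens from the front when context + new tokens would exceed max_length.
    begin_idx = max(0, current_max_len - max_length)
    # If begin_idx >= context_len the slice is empty: nothing is left to feed the LLM.
    return len(context[begin_idx:])


# 600 context tokens is an illustrative guess, not a measured value.
print(kept_context_length(600, max_new_tokens=2048, max_length=1024))  # 0
print(kept_context_length(600, max_new_tokens=2048, max_length=4096))  # 600
print(kept_context_length(600, max_new_tokens=1024, max_length=8192))  # 600
```

If the wrapper does work this way, any `max_length` smaller than `max_new_tokens` empties the context, whatever the video length.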

mutonix commented 1 month ago

Maybe 1024 is too short for the video input, since max_length has to leave room for the video context as well as the generated tokens. Please try setting max_length to 4096, or set max_new_tokens=1024 instead of setting max_length.
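For reference, the Hugging Face `transformers` generation docs describe `max_length` as the prompt length plus `max_new_tokens`, and state that `max_new_tokens` overrides `max_length` when both are set. The sketch below models that rule for a plain `generate()` call on token IDs. `generation_budget` is a hypothetical helper, the prompt lengths are illustrative, and the Vript `chat.answer` wrapper may apply its own length handling on top of this.

```python
from typing import Optional


def generation_budget(input_len: int, max_length: int = 20,
                      max_new_tokens: Optional[int] = None) -> int:
    """How many new tokens a plain generate() call may produce (simplified model)."""
    if max_new_tokens is not None:
        # max_new_tokens ignores the prompt length and takes precedence over max_length.
        return max_new_tokens
    # max_length counts prompt tokens and generated tokens together.
    if input_len >= max_length:
        # This sketch simply rejects the case; it does not mimic library behavior here.
        raise ValueError("the prompt alone already fills max_length")
    return max_length - input_len


# A long (video) prompt eats into max_length, but not into max_new_tokens.
print(generation_budget(600, max_length=1024))                        # 424
print(generation_budget(3000, max_length=4096))                       # 1096
print(generation_budget(3000, max_length=1024, max_new_tokens=1024))  # 1024
```

This is why bounding the output with `max_new_tokens` is less sensitive to how many tokens the video contributes than bounding the total with `max_length`.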