OpenGVLab / Ask-Anything

[CVPR2024 Highlight][VideoChatGPT] ChatGPT with video understanding! And many more supported LMs such as miniGPT4, StableLM, and MOSS.
https://vchat.opengvlab.com/

Cannot run inference using API for video chat with Stable_LM #70

Open plischwe opened 8 months ago

plischwe commented 8 months ago

I have successfully launched the "Ask Anything with StableLM" model and obtained a public URL from Gradio, but I'm running into an error when sending an example request through the Gradio client. The file I am running is:

from gradio_client import Client

client = Client("https://3774b146370bec32fe.gradio.live/")
result = client.predict(
    "https://github.com/gradio-app/gradio/raw/main/test/test_files/video_sample.mp4",  # str (filepath on your computer (or URL) of file) in 'Input Video' Video component
    "Howdy!",  # str in 'User Prompt (Optional, Enter with commas)' Textbox component
    fn_index=4
)
print(result)

The error that I am receiving is:

Traceback (most recent call last):
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/gradio/queueing.py", line 388, in call_prediction
    output = await route_utils.call_process_api(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/gradio/route_utils.py", line 217, in call_process_api
    output = await app.get_blocks().process_api(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/gradio/blocks.py", line 1554, in process_api
    result = await self.call_function(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/gradio/blocks.py", line 1192, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/anyio/to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread
    return await future
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/anyio/_backends/_asyncio.py", line 807, in run
    result = context.run(func, *args)
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/gradio/utils.py", line 659, in wrapper
    response = f(*args, **kwargs)
  File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/app.py", line 77, in inference
    caption, tag_predict = model.generate(image, tag_input=input_tag_list, max_length=50, return_tag_predict=True)
  File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/models/tag2text.py", line 200, in generate
    outputs = self.text_decoder.generate(input_ids=input_ids,
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/transformers/generation/utils.py", line 1685, in generate
    return self.beam_search(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/transformers/generation/utils.py", line 3024, in beam_search
    outputs = self(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/models/med.py", line 962, in forward
    outputs = self.bert(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/models/med.py", line 858, in forward
    encoder_outputs = self.encoder(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/models/med.py", line 522, in forward
    layer_outputs = layer_module(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/models/med.py", line 438, in forward
    cross_attention_outputs = self.crossattention(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/models/med.py", line 333, in forward
    self_outputs = self.self(
  File "/home/anaconda3/envs/videochat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/plischwe/Ask-Anything/video_chat_with_StableLM/models/med.py", line 234, in forward
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
RuntimeError: The size of tensor a (24) must match the size of tensor b (9) at non-singleton dimension 0

Any input would be greatly appreciated - thanks.

yinanhe commented 5 months ago

To solve this problem, you can refer to https://github.com/salesforce/BLIP/issues/142#issuecomment-1500815672

The following could be a possible solution:

Change attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) at line 234 of /home/plischwe/Ask-Anything/video_chat_with_StableLM/models/med.py to:

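# If the encoder key/value batch is larger than the query batch, trim the
# key/value tensors and the attention mask down to the query batch size.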
if key_layer.shape[0] > query_layer.shape[0]:
    key_layer = key_layer[:query_layer.shape[0], :, :, :]
    attention_mask = attention_mask[:query_layer.shape[0], :, :]
    value_layer = value_layer[:query_layer.shape[0], :, :, :]
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
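For reference, here is a minimal, self-contained sketch of the underlying shape mismatch and of how aligning the batch dimension lets the attention matmul go through. The tensor names and sizes below are illustrative only (they are not taken from med.py), and the snippet mirrors the idea of the fix above rather than the actual model code:

import torch

# Illustrative shapes: (batch, heads, seq_len, head_dim). The query and the
# encoder key disagree on dimension 0, e.g. 24 vs 9 as in the traceback above.
query_layer = torch.randn(24, 12, 1, 64)
key_layer = torch.randn(9, 12, 577, 64)

try:
    torch.matmul(query_layer, key_layer.transpose(-1, -2))
except RuntimeError as e:
    # RuntimeError: The size of tensor a (24) must match the size of
    # tensor b (9) at non-singleton dimension 0
    print(e)

# Trimming both tensors to the smaller batch makes the shapes compatible.
n = min(query_layer.shape[0], key_layer.shape[0])
attention_scores = torch.matmul(query_layer[:n], key_layer[:n].transpose(-1, -2))
print(attention_scores.shape)  # torch.Size([9, 12, 1, 577])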