X-PLUG / mPLUG-Owl

mPLUG-Owl: The Powerful Multi-modal Large Language Model Family
https://www.modelscope.cn/studios/damo/mPLUG-Owl
MIT License

Fine-tune customized video dataset based on mPLUG_owl. #130

Closed hcwei13 closed 1 year ago

hcwei13 commented 1 year ago

Thank you to the team for the outstanding work. I would like to fine-tune mPLUG_owl on a custom video dataset. I tried adapting the image-based fine-tuning code, but hit this error: RuntimeError: "conv_depthwise3d" not implemented for 'BFloat16'. I suspect this might be because the model is initially loaded onto the CPU. I then set 'bf16' to False, which made 'video_query_output' contain NaN values. I also downloaded the temporary weights to replace those from the HuggingFace repository, but that had no effect. This issue has been bothering me for a while, so I would appreciate any suggestions. Thanks!

hcwei13 commented 1 year ago

@MAGAer13

MAGAer13 commented 1 year ago

You can cast the input of the depthwise conv3d to fp16, then cast the result back.
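For readers following along, here is a minimal sketch of that suggestion. It is not code from the repository: the helper name conv3d_in_dtype, the layer shapes, and the compute_dtype parameter are all hypothetical. It also disables autocast locally, on the unconfirmed assumption that an enclosing bf16 autocast context could otherwise cast the operands back to bfloat16.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv3d_in_dtype(conv: nn.Conv3d, x: torch.Tensor,
                    compute_dtype: torch.dtype = torch.float16) -> torch.Tensor:
    """Run `conv` on `x` in `compute_dtype`, then cast back to x's dtype."""
    orig_dtype = x.dtype
    # Cast the weights as well as the input; conv3d needs matching dtypes.
    weight = conv.weight.to(compute_dtype)
    bias = None if conv.bias is None else conv.bias.to(compute_dtype)
    # Assumption: an outer bf16 autocast context may re-cast the operands,
    # so autocast is switched off for this one call.
    with torch.autocast(device_type=x.device.type, enabled=False):
        out = F.conv3d(x.to(compute_dtype), weight, bias,
                       stride=conv.stride, padding=conv.padding,
                       dilation=conv.dilation, groups=conv.groups)
    return out.to(orig_dtype)


# Depthwise temporal convolution: groups == channels (illustrative shapes).
channels = 8
conv = nn.Conv3d(channels, channels, kernel_size=(3, 1, 1),
                 padding=(1, 0, 0), groups=channels).to(torch.bfloat16)
x = torch.randn(2, channels, 4, 5, 5, dtype=torch.bfloat16)

# float32 is used here so the sketch also runs on a CPU-only machine;
# on a GPU, compute_dtype=torch.float16 matches the suggestion above.
y = conv3d_in_dtype(conv, x, compute_dtype=torch.float32)
```

The weights are cast per call rather than stored in the compute dtype, so the module's parameters stay in bfloat16 for the rest of the model.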

zhouwei5113 commented 1 year ago

@hcwei13 Did you solve the issue?

Hritikbansal commented 1 year ago

@LukeForeverYoung @MAGAer13 I was under the impression that https://github.com/X-PLUG/mPLUG-Owl/blob/main/mplug_owl_video/modeling_mplug_owl.py#L214 already applies the solution you suggested above, since the bfloat16 input is manually cast to half there, but the issue still persists for me when I run this code for fine-tuning.

@hcwei13 @zhouwei5113 Is this issue fixed for you, or have you found a workaround?

hcwei13 commented 1 year ago

Unfortunately, this problem still exists. @Hritikbansal @zhouwei5113 @LukeForeverYoung @MAGAer13

Hritikbansal commented 1 year ago

Hi @hcwei13, I fixed the problem as follows:

x = self.down_proj(x)
_device = x.device
# Hack: offload to CPU, because conv_depthwise3d raises an error for
# bfloat16 on GPU but works on CPU.
self = self.to('cpu')
x = x.to('cpu')
x = self.conv(x)
# Move the module and the activations back to the original device.
self = self.to(_device)
x = x.to(_device)
x = self.activation_func(x)

My change may add some latency, since it offloads the convolution to the CPU, but it makes the code work on GPU and multi-GPU setups.
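A variant of the same CPU-offloading idea can be sketched as a standalone helper. This is an illustration under assumptions, not the code posted above: the name conv3d_on_cpu and the layer shapes are hypothetical, and unlike the snippet above it copies only the convolution's weights to the CPU instead of moving the whole module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv3d_on_cpu(conv: nn.Conv3d, x: torch.Tensor) -> torch.Tensor:
    """Evaluate `conv` on the CPU and return the result on x's device."""
    device = x.device
    # .to("cpu") is differentiable (and a no-op if already on the CPU),
    # so gradients still reach the original parameters and the module
    # itself never changes device.
    weight = conv.weight.to("cpu")
    bias = None if conv.bias is None else conv.bias.to("cpu")
    out = F.conv3d(x.to("cpu"), weight, bias, stride=conv.stride,
                   padding=conv.padding, dilation=conv.dilation,
                   groups=conv.groups)
    return out.to(device)


# Depthwise temporal convolution with illustrative shapes.
channels = 8
conv = nn.Conv3d(channels, channels, kernel_size=(3, 1, 1),
                 padding=(1, 0, 0), groups=channels)
x = torch.randn(2, channels, 4, 5, 5)
y = conv3d_on_cpu(conv, x)
y.sum().backward()  # gradients flow back to conv.weight
```

The latency caveat still applies, because the activations make a round trip between devices on every forward pass.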

hcwei13 commented 1 year ago

It works!!! thanks!!! @Hritikbansal

GarrettLee commented 11 months ago

Hi, can you share your code? It would be much appreciated.