X-PLUG / mPLUG-Owl

mPLUG-Owl: The Powerful Multi-modal Large Language Model Family
https://www.modelscope.cn/studios/damo/mPLUG-Owl
MIT License
2.25k stars 171 forks source link

expected scalar type BFloat16 but found Half #104

Closed hhhhnwl closed 1 year ago

hhhhnwl commented 1 year ago

Running the video model in 8-bit mode raises an error:

import torch
from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
from transformers import AutoTokenizer
from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor

# Reproduction script: load the mPLUG-Owl video checkpoint and run one
# video-question prompt through generate().

# Local path to the video checkpoint; adjust for your environment.
pretrained_ckpt = '/data1/project/CLIP/mPLUG-Owl/mplug-owl-llama-7b-video'

# NOTE: load_in_8bit=True combined with torch_dtype=torch.bfloat16 is the
# configuration under test in this report. It is kept as-is on purpose; the
# run fails in a conv2d call with
# "RuntimeError: expected scalar type BFloat16 but found Half".
model = MplugOwlForConditionalGeneration.from_pretrained(
    pretrained_ckpt,
    load_in_8bit=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
tokenizer = AutoTokenizer.from_pretrained(pretrained_ckpt)
processor = MplugOwlProcessor(image_processor, tokenizer)

# One prompt; the <|video|> placeholder marks where the video is referenced.
prompts = [
'''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: <|video|>
Human: what is that
AI: ''']

# Fixed: the original literal mixed quote styles ("test.mp4'), which is a
# SyntaxError and prevented the script from running at all.
video_list = ["test.mp4"]

generate_kwargs = {
    'do_sample': True,
    'top_k': 1,
    'max_length': 512,
}
inputs = processor(text=prompts, videos=video_list, num_frames=32, return_tensors='pt')
# Cast only float32 tensors to bfloat16; tensors of any other dtype are left
# unchanged.
inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
# Move every input tensor to the device reported by the model.
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
    res = model.generate(**inputs, **generate_kwargs)
# Decode the first (only) generated sequence.
sentence = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
print(sentence)

  450             return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=sel
  451                             weight, bias, self.stride,
  452                             _pair(0), self.dilation, self.groups)
❱ 453         return F.conv2d(input, weight, bias, self.stride,
  454                         self.padding, self.dilation, self.groups)
  455
  456     def forward(self, input: Tensor) -> Tensor:

RuntimeError: expected scalar type BFloat16 but found Half

MAGAer13 commented 1 year ago

For the video version, we only support bfloat16; 8-bit loading is not supported.

MAGAer13 commented 1 year ago

8-bit loading is not compatible with bfloat16.