Open wubangcai opened 1 month ago
是的,缺少多轮demo
+1
+1
我进一步封装了官方代码成为Qwen2VL类,自己实现了chat函数。
用户只需要输入自然语言形式的query,以url、本地路径、base64格式的图片imgs(单图和多图都支持),前文对话history即可,非常简单易用。
copy下面的代码
class Qwen2VL:
    """Multi-turn chat wrapper around Qwen2-VL.

    Accepts a natural-language ``query``, optional image(s) given as a URL,
    local path, or base64 string (single image or a list), and an optional
    conversation ``history``; returns the model response together with the
    updated history so callers can chain turns.
    """

    def __init__(self, model_path=None, max_new_tokens=1024,
                 min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28):
        """Load the model and processor.

        Args:
            model_path: Local path or hub repo id of a Qwen2-VL checkpoint.
            max_new_tokens: Generation cap forwarded to ``model.generate``.
            min_pixels: Lower visual-token budget forwarded to the processor.
            max_pixels: Upper visual-token budget forwarded to the processor.
        """
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype="auto",
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(
            model_path, min_pixels=min_pixels, max_pixels=max_pixels
        )
        self.gen_config = {
            "max_new_tokens": max_new_tokens,
        }

    def parse_input(self, query=None, imgs=None):
        """Build a single-turn ``messages`` list from query and images.

        Args:
            query: User text.
            imgs: ``None``, a single image (str), or a list of images; each
                image may be a URL, a local path, or a base64 string.

        Returns:
            A one-element list holding the user message in Qwen2-VL chat
            format (plain string content for text-only turns, a content
            list when images are present).
        """
        # Fix: treat an empty image list like "no images", not just None,
        # so chat(query=..., imgs=[]) produces a valid text-only turn.
        if not imgs:
            return [{"role": "user", "content": query}]
        if isinstance(imgs, str):
            imgs = [imgs]
        content = [{"type": "image", "image": img} for img in imgs]
        content.append({"type": "text", "text": query})
        return [{"role": "user", "content": content}]

    def chat(self, query=None, imgs=None, history=None):
        """Run one chat turn and return ``(response, updated_history)``.

        Args:
            query: User text for this turn.
            imgs: Optional image(s) for this turn (see ``parse_input``).
            history: Prior conversation messages; pass the history returned
                by the previous call to continue a conversation.

        Returns:
            Tuple of the decoded assistant reply and the updated history.

        NOTE: ``history`` is mutated in place (the new user turn and the
        assistant reply are appended) in addition to being returned.
        """
        if history is None:
            history = []
        history.extend(self.parse_input(query, imgs))
        text = self.processor.apply_chat_template(
            history, tokenize=False, add_generation_prompt=True, add_vision_id=True
        )
        image_inputs, video_inputs = process_vision_info(history)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Fix: follow the model's own device instead of hard-coding "cuda",
        # so CPU-only machines and sharded device_map="auto" placements
        # (where inputs must land on the first device) both work.
        inputs = inputs.to(self.model.device)
        generated_ids = self.model.generate(**inputs, **self.gen_config)
        # Strip the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        history.append({"role": "assistant", "content": response})
        # Release per-turn tensors promptly to keep GPU memory flat
        # across long conversations.
        del inputs, generated_ids, generated_ids_trimmed
        torch.cuda.empty_cache()
        gc.collect()
        return response, history
# Example: multi-turn chat, first text-only, then with an image.
# (The pasted original had several statements fused onto single lines,
# which is not valid Python; reconstructed here as a runnable script.)
chat_model = Qwen2VL(model_path="local path/repo id")

history = None
response, history = chat_model.chat(query="hello", history=history)
print(response, history)

response, history = chat_model.chat(
    query="please describe the image", imgs=["image_url"], history=history
)
print(response, history)
# 执行结果
![image](https://github.com/user-attachments/assets/d530d2d1-cf66-45f1-9d39-e582f72a349b)
在使用transformer进行模型推理时,多轮对话的messages应该怎么构造?(类似qwen_vl 中history的参数怎么设置)