Open The-kamisato opened 7 months ago
Hi,感谢你对我们工作的关注和提问。因为我们默认CogCoM是一个针对图像的多模态模型,所以在训练时没有考虑输入不包含图像的情况。然而,你可以通过将chat.py
中198~228行的代码替换为如下代码,使得模型兼容兼容不包含图像输入的情况:
# if image_position < 5: # no image
# inputs = text_processor.tokenizer([prompt], return_tensors="pt").to(model.parameters().__next__().device)['input_ids'][0]
# # pre_image = 0
# else:
new_prompt = prompt[image_position:] if image_position >= 5 else prompt[image_position+1:]
# new_prompt = prompt[image_position:]
if not torch_image or hasattr(text_processor, 'no_eoi'):
new_prompt = new_prompt.replace(text_processor.tokenizer.eoi, '', 1)
inputs_dic = text_processor(new_prompt)
for k in inputs_dic:
if type(inputs_dic[k]) is torch.Tensor and inputs_dic[k].dtype is not torch.int and inputs_dic[k].dtype is not torch.long:
inputs_dic[k] = inputs_dic[k].to(next(model.parameters()).dtype)
if type(inputs_dic[k]) is torch.Tensor:
inputs_dic[k] = inputs_dic[k].to(next(model.parameters()).device)
inputs = inputs_dic['input_ids'].to(model.parameters().__next__().device)[0]
# pre_image = inputs_dic['pre_image']
seq = torch.cat(
[inputs, torch.tensor([-1]*(max_length-len(inputs)), device=inputs.device)], dim=0
)
strategy = BaseStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[text_processor.tokenizer.eos_token_id],
invalid_slices=invalid_slices, repetition_penalty=repetition_penalty)
get_func = text_processor.get_func(inputs, **inputs_dic) if hasattr(text_processor, 'get_func') else get_masks_and_position_ids_default
if image_position < 5:
# inputs = {}
inputs_dic.pop('input_ids')
inputs = {**inputs_dic}
else:
inputs = {**{'vision_'+k:v for k,v in torch_image.items()}, **{'cross_'+k:v for k,v in cross_image.items()}}
inputs_dic.pop('input_ids')
inputs = {**inputs, **inputs_dic}
然而需要注意的是,我们目前的模型版本在多模态训练阶段没有纯文本样本(在上下文窗口中mask掉),所以经过测试发现在纯文本输入的情况下模型的回复效果较差。需要通过结合纯文本的微调来缓解,关于微调可以参考finetune.py
。
我注意到你们chat.py文件里面219行: get_func = text_processor.get_func(inputs, **inputs_dic) if hasattr(text_processor, 'get_func') else get_masks_and_position_ids_default 可是如果我一开始没有输入图,image_position < 5,那么inputs_dic不会被赋值text_processor(new_prompt) (205行),就会报错“在变量定义之前使用” 请问这个怎么解决啊,谢谢