THUDM / CogCoM

Other
146 stars 9 forks source link

有个bug运行不起来 #8

Open The-kamisato opened 7 months ago

The-kamisato commented 7 months ago

我注意到你们chat.py文件里面219行: get_func = text_processor.get_func(inputs, **inputs_dic) if hasattr(text_processor, 'get_func') else get_masks_and_position_ids_default 可是如果我一开始没有输入图,image_position < 5,那么inputs_dic不会被赋值text_processor(new_prompt) (205行),就会报错“在变量定义之前使用” 请问这个怎么解决啊,谢谢

qijimrc commented 6 months ago

Hi,感谢你对我们工作的关注和提问。因为我们默认CogCoM是一个针对图像的多模态模型,所以在训练时没有考虑输入不包含图像的情况。然而,你可以通过将chat.py中198~228行的代码替换为如下代码,使得模型兼容兼容不包含图像输入的情况:

        # if image_position < 5: # no image
        #     inputs = text_processor.tokenizer([prompt], return_tensors="pt").to(model.parameters().__next__().device)['input_ids'][0]
        #     # pre_image = 0
        # else:
        new_prompt = prompt[image_position:] if image_position >= 5 else prompt[image_position+1:]
        # new_prompt = prompt[image_position:]
        if not torch_image or hasattr(text_processor, 'no_eoi'):
            new_prompt = new_prompt.replace(text_processor.tokenizer.eoi, '', 1)
        inputs_dic = text_processor(new_prompt)
        for k in inputs_dic:
            if type(inputs_dic[k]) is torch.Tensor and inputs_dic[k].dtype is not torch.int and inputs_dic[k].dtype is not torch.long:
                inputs_dic[k] = inputs_dic[k].to(next(model.parameters()).dtype)
            if type(inputs_dic[k]) is torch.Tensor:
                inputs_dic[k] = inputs_dic[k].to(next(model.parameters()).device)
        inputs = inputs_dic['input_ids'].to(model.parameters().__next__().device)[0]
        # pre_image = inputs_dic['pre_image']

        seq = torch.cat(
            [inputs, torch.tensor([-1]*(max_length-len(inputs)), device=inputs.device)], dim=0
        )
        strategy = BaseStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[text_processor.tokenizer.eos_token_id],
                                invalid_slices=invalid_slices, repetition_penalty=repetition_penalty)
        get_func = text_processor.get_func(inputs, **inputs_dic) if hasattr(text_processor, 'get_func') else get_masks_and_position_ids_default
        if image_position < 5:
            # inputs = {}
            inputs_dic.pop('input_ids')
            inputs = {**inputs_dic}
        else:
            inputs = {**{'vision_'+k:v for k,v in torch_image.items()}, **{'cross_'+k:v for k,v in cross_image.items()}}
            inputs_dic.pop('input_ids')
            inputs = {**inputs, **inputs_dic}

然而需要注意的是,我们目前的模型版本在多模态训练阶段没有纯文本样本(在上下文窗口中mask掉),所以经过测试发现在纯文本输入的情况下模型的回复效果较差。需要通过结合纯文本的微调来缓解,关于微调可以参考finetune.py