Ucas-HaoranWei / Vary-toy

Official code implementation of Vary-toy (Small Language Model Meets with Reinforced Vision Vocabulary)

Difficulty controlling the language of the generated output #36

Open TekhneC opened 4 days ago

TekhneC commented 4 days ago

After making slight modifications to the provided demo, I found that for this image (raw_img) the language of the generated output is unstable. [image]

Loading the model

# Model
disable_torch_init()
# model_name = os.path.expanduser("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("../../autodl-tmp/result/ckpt/checkpoint-5000", trust_remote_code=True)  # model trained for 5000 steps
model = varyQwenForCausalLM.from_pretrained("../../autodl-tmp/result/ckpt/checkpoint-5000", low_cpu_mem_usage=True,device_map = 'auto')
# model.to(device='cuda',  dtype=torch.bfloat16)

image_processor = CLIPImageProcessor.from_pretrained("../../autodl-tmp/model/clip", torch_dtype=torch.float16, device_map='auto')
image_processor_high = BlipImageEvalProcessor(image_size=1024)
use_im_start_end = True
image_token_len = 256
conv_mode = "vqa" #我创建的对话模板

Evaluation code

def eval_model():
    qs = 'Is there any man?'
    if use_im_start_end:
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN  + '\n' + qs
    else:
        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

    qs1 = 'What game are the men playing?'
    if use_im_start_end:
        qs1 = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN  + '\n' + qs1
    else:
        qs1 = DEFAULT_IMAGE_TOKEN + '\n' + qs1

    # args.conv_mode = conv_mode
    conv_mode = "vqa"
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    print(prompt)

    conv1 = conv_templates[conv_mode].copy()
    conv1.append_message(conv.roles[0], qs1)
    conv1.append_message(conv.roles[1],None)
    prompt1 = conv1.get_prompt()
    # print(prompt1)

    transform =  transforms.Compose([transforms.ToTensor()]) 
    inputs = tokenizer([prompt,prompt1], padding=True)

    print(len(inputs['input_ids']))
    inputs = torch.tensor(inputs['input_ids'])
    path = "../../autodl-tmp/data/vqav2/imgs/COCO_test2015_000000262144.jpg"
    with open(path, 'rb') as open_file:
        raw_img = Image.open(open_file).convert('RGB')
        image_1 = raw_img.copy()
    image_tensor = image_processor.preprocess(raw_img, return_tensors='pt')['pixel_values'][0]
    image_tensor_1 = image_processor_high(image_1)
    input_ids = inputs.cuda()

    # stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    # streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.autocast("cuda", dtype=torch.bfloat16):
        output_ids = model.generate(
            input_ids,
            images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda()),
                    (image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
            do_sample=True,
            num_beams=1,
            # temperature=0.2,
            # streamer=streamer,
            max_new_tokens=64,
            stopping_criteria=[stopping_criteria]
            )

        print(output_ids.shape)

        outputs = tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:])

        # conv.messages[-1][-1] = outputs
        out = []
        for o in outputs:
            o = o.strip()
            if o.endswith(stop_str):
                o = o[:o.find(stop_str)]
                o = o.strip()
            out.append(o)

        print(out)

eval_model()
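
One detail that may matter for the instability: generate() is called with do_sample=True and no temperature, so it samples at the default temperature and repeated runs on the same prompt can differ. A greedy variant of the same call (my own sketch, only sampling disabled, everything else unchanged) would be:

# Sketch of a greedy-decoding variant of the generate() call inside eval_model():
# with do_sample=False the output is deterministic for a given prompt.
output_ids = model.generate(
    input_ids,
    images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda()),
            (image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
    do_sample=False,
    num_beams=1,
    max_new_tokens=64,
    stopping_criteria=[stopping_criteria],
)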

Attempting to control the output language

I created my own conversation template to try to solve the problem:

conv_my = Conversation(
    system="""<|im_start|>Give your answer in English.""",
    # system = None,
    roles=("<|im_start|>human\n", "<|im_start|>gpt\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)
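
The template is then used in place of conv_templates["vqa"] inside eval_model(); a minimal sketch of that substitution (assuming the same Conversation API as above):

# Hypothetical usage of the custom template instead of the registered "vqa" one.
conv = conv_my.copy()
conv.append_message(conv.roles[0], qs)   # qs already contains the image tokens
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()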

The generated results were: ['慢慢地,这个人跑向本垒板。', 'The men are playing a game of baseball.'] (the first answer is in Chinese, roughly: 'Slowly, the man runs toward home plate.'). I expected the model to produce English.

TekhneC commented 4 days ago

Inference results with batch = 16: [image]