THUDM / CogVLM2

GPT4V-level open-source multi-modal model based on Llama3-8B
Apache License 2.0

Issue: Chinese and English inference results differ greatly when calling the model via __call__ #112

Closed ChengjieLi28 closed 5 months ago

ChengjieLi28 commented 5 months ago

System Info

Python 3.10, latest transformers

Who can help?

@zRzRzRzRzRzRzR

Information

Reproduction / 复现过程

Background: I cannot use the generate method for inference, because I need to control the details of the decoding loop, for example to get hold of the kv_cache (past_key_values). So I can only call the model directly via __call__.

Observed behavior: with exactly the same code and the Chinese model THUDM/cogvlm2-llama3-chinese-chat-19B-int4:

The English query works fine:

(screenshot: model output for the English query)

The Chinese query:

(screenshot: model output for the Chinese query)

The Chinese answer seems to stop too early. I printed every token produced during decoding and found that 128001 (the stop_token_id) appears prematurely, which causes this behavior.
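
As a side check, a quick way to confirm which special token id 128001 corresponds to, and whether it matches the tokenizer's own EOS token, using the standard Hugging Face tokenizer API (same MODEL_PATH as in the script below):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
print(tokenizer.convert_ids_to_tokens(128001))      # token string behind id 128001
print(tokenizer.eos_token, tokenizer.eos_token_id)  # the tokenizer's EOS token and id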

Do you have any ideas or suggestions? Where could the problem be?

Expected behavior

Chinese queries work correctly.

iyuge2 commented 5 months ago

@ChengjieLi28 Hi, could you share your modified script? We'll try to reproduce the problem locally. Thanks!

ChengjieLi28 commented 5 months ago

> @ChengjieLi28 Hi, could you share your modified script? We'll try to reproduce the problem locally. Thanks!

Sure, I'll post a minimal reproduction shortly.

ChengjieLi28 commented 5 months ago

@iyuge2 Hi, here is the code:

import os
import time
import torch

from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, TemperatureLogitsWarper, \
    TopPLogitsWarper

MODEL_PATH = "<model_path>"
PIC_PATH = "<picture_path>"
TORCH_TYPE = torch.bfloat16
device = 'cuda:0'
query = "Describe the picture."
image = Image.open(PIC_PATH).convert('RGB')
temperature = 0.6
top_p = 0.9
max_new_tokens = 512

def recur_move_to(item, tgt, criterion_func):
    if criterion_func(item):
        device_copy = item.to(tgt)
        return device_copy
    elif isinstance(item, list):
        return [recur_move_to(v, tgt, criterion_func) for v in item]
    elif isinstance(item, tuple):
        return tuple([recur_move_to(v, tgt, criterion_func) for v in item])
    elif isinstance(item, dict):
        return {k: recur_move_to(v, tgt, criterion_func) for k, v in item.items()}
    else:
        return item

def collate_fn(features, tokenizer) -> dict:
    images = [feature.pop('images', None) for feature in features if 'images' in feature]
    tokenizer.pad_token = tokenizer.eos_token
    max_length = max(len(feature['input_ids']) for feature in features)

    def pad_to_max_length(feature, max_length):
        padding_length = max_length - len(feature['input_ids'])
        print(f"===padding_length: {padding_length}")
        feature['input_ids'] = torch.cat([feature['input_ids'], torch.full((padding_length,), tokenizer.pad_token_id)])
        feature['token_type_ids'] = torch.cat([feature['token_type_ids'], torch.zeros(padding_length, dtype=torch.long)])
        feature['attention_mask'] = torch.cat([feature['attention_mask'], torch.zeros(padding_length, dtype=torch.long)])
        if feature['labels'] is not None:
            feature['labels'] = torch.cat([feature['labels'], torch.full((padding_length,), tokenizer.pad_token_id)])
        else:
            feature['labels'] = torch.full((max_length,), tokenizer.pad_token_id)
        return feature

    features = [pad_to_max_length(feature, max_length) for feature in features]
    batch = {
        key: torch.stack([feature[key] for feature in features])
        for key in features[0].keys()
    }

    if images:
        batch['images'] = images

    return batch

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_TYPE,
    trust_remote_code=True,
    device_map=device,
    # load_in_4bit=True,
    # low_cpu_mem_usage=True
).eval()

input_sample_list = []
input_sample = model.build_conversation_input_ids(tokenizer, query=query, history=[], images=[image], template_version='chat')
input_sample_list.append(input_sample)

input_batch = collate_fn(input_sample_list, tokenizer)
input_batch = recur_move_to(input_batch, device, lambda x: isinstance(x, torch.Tensor))
input_batch = recur_move_to(input_batch, torch.bfloat16, lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x))
print(input_batch.keys())

def prepare_logits_processor(
    temperature: float, top_p: float
) -> LogitsProcessorList:
    processor_list = LogitsProcessorList()
    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
    if temperature >= 1e-5 and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if 1e-8 <= top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    return processor_list

def _get_token_from_logits(
    logits, temperature, top_p
):
    logits_processor = prepare_logits_processor(
        temperature, top_p
    )
    last_token_logits = logits_processor(None, logits[0:1, -1, :])[
        0
    ]
    probs = torch.softmax(last_token_logits, dim=-1)
    indices = torch.multinomial(probs, num_samples=2)
    token = indices[0].int().item()
    return token

def _get_inference_kws(kv):
    batch_size, seq_length, device = (
        kv[0][0].shape[0],
        kv[0][0].shape[2] + 1,
        kv[0][0].device,
    )
    res = {}
    position_ids = torch.full(
        (batch_size, 1), fill_value=seq_length - 1, dtype=torch.long, device=device
    )
    attention_mask = torch.ones(
        (batch_size, seq_length), dtype=torch.long, device=device
    )
    res["attention_mask"] = attention_mask
    res["position_ids"] = position_ids
    token_type_ids = torch.full(
        (batch_size, 1), fill_value=1, dtype=torch.long, device=device
    )
    res["token_type_ids"] = token_type_ids
    return res

@torch.inference_mode()
def execute():
    # prefill
    new_tokens = []
    out = model(**input_batch, use_cache=True)
    logits = out.logits
    kv_cache = out.past_key_values

    token = _get_token_from_logits(logits, temperature, top_p)
    new_tokens.append(token)

    # decode
    for idx in range(max_new_tokens):
        decode_tokens = [[new_tokens[-1]]]
        inf_kws = _get_inference_kws(kv_cache)
        out = model(
            input_ids=torch.as_tensor(decode_tokens, device=device),
            use_cache=True,
            past_key_values=kv_cache,
            **inf_kws,
        )
        logits = out.logits
        kv_cache = out.past_key_values

        token = _get_token_from_logits(logits, temperature, top_p)
        new_tokens.append(token)
        print(f"Output: {token}")
        if token == 128001:
            break

    res = tokenizer.decode(
            new_tokens,
            skip_special_tokens=True,
            spaces_between_special_tokens=False,
            clean_up_tokenization_spaces=True)
    print(f"Results: {res}")

if __name__ == '__main__':
    execute()

Notes:

My results. Query: "Describe the picture."

(screenshot: output for the English query)

Query: "描述此图" ("Describe this picture")

(screenshot: output for the Chinese query)

Additional note: the model is THUDM/cogvlm2-llama3-chinese-chat-19B-int4, downloaded from ModelScope.

iyuge2 commented 5 months ago

@ChengjieLi28 Got it, we'll take a look when we have time and get back to you with the results.

iyuge2 commented 5 months ago

@ChengjieLi28 Hi, I looked into this. The main problem is in how the position ids are handled.

  1. In the prefill stage, your earlier code did not pass position_ids;
  2. Every time the next token is generated, the operation below is wrong: because all image-token positions are given the same position_id value, the position_id of the current token is not equal to seq_length - 1 (see the illustrative sketch after this snippet).
position_ids = torch.full(
        (batch_size, 1), fill_value=seq_length - 1, dtype=torch.long, device=device
    )
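
To illustrate the point, here is a minimal sketch; the toy_position_ids helper is simply an inline copy of the build_position_ids logic from the script below. All interior vision tokens share a single position, so the last position id is smaller than the sequence length, and the first decoded token must continue from max_position_id + 1 rather than from the KV-cache length.

import torch

LANGUAGE_TOKEN_TYPE = 0
VISION_TOKEN_TYPE = 1

def toy_position_ids(token_type_ids):
    # Boundary (boi/eoi) vision tokens are treated as language tokens;
    # the remaining vision tokens all keep the same position.
    tmp = token_type_ids.clone()
    is_boi_eoi = torch.zeros_like(tmp, dtype=torch.bool)
    is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
    is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
    is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
    is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
    tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
    y = torch.zeros_like(tmp, dtype=torch.long)
    y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
        (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
    )
    return y.cumsum(dim=-1)

# 1 language token, 4 vision tokens, 2 language tokens
token_type_ids = torch.tensor([[0, 1, 1, 1, 1, 0, 0]])
print(toy_position_ids(token_type_ids))  # tensor([[0, 1, 2, 2, 3, 4, 5]])
# After prefilling these 7 tokens the KV cache holds 7 entries, so the original
# script would pass position_id = 7 for the first decoded token, while the
# correct value is max_position_id + 1 = 6. With the many vision tokens of a
# real image the gap becomes much larger.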

Below is my lightly modified version of the code; the places that differ from yours are marked with Fix: ***. My preliminary tests look fine, but since I have not run detailed benchmark evaluations, I still recommend stepping through the official generate implementation and restructuring the overall logic based on it.

import os
import time
import torch

from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, TemperatureLogitsWarper, \
    TopPLogitsWarper, TopKLogitsWarper, BitsAndBytesConfig

MODEL_PATH = "" # ckpt
PIC_PATH = "" # image_path
TORCH_TYPE = torch.bfloat16
LANGUAGE_TOKEN_TYPE = 0
VISION_TOKEN_TYPE = 1
device = 'cuda'
query = ""
# query = "Describe this image."
image = Image.open(PIC_PATH).convert('RGB')
temperature = 0.6
top_p = 0.9
top_k = 1
max_new_tokens = 512

def recur_move_to(item, tgt, criterion_func):
    if criterion_func(item):
        device_copy = item.to(tgt)
        return device_copy
    elif isinstance(item, list):
        return [recur_move_to(v, tgt, criterion_func) for v in item]
    elif isinstance(item, tuple):
        return tuple([recur_move_to(v, tgt, criterion_func) for v in item])
    elif isinstance(item, dict):
        return {k: recur_move_to(v, tgt, criterion_func) for k, v in item.items()}
    else:
        return item

def collate_fn(features, tokenizer) -> dict:
    images = [feature.pop('images', None) for feature in features if 'images' in feature]
    tokenizer.pad_token = tokenizer.eos_token
    max_length = max(len(feature['input_ids']) for feature in features)

    def pad_to_max_length(feature, max_length):
        padding_length = max_length - len(feature['input_ids'])
        print(f"===padding_length: {padding_length}")
        feature['input_ids'] = torch.cat([feature['input_ids'], torch.full((padding_length,), tokenizer.pad_token_id)])
        feature['token_type_ids'] = torch.cat([feature['token_type_ids'], torch.zeros(padding_length, dtype=torch.long)])
        feature['attention_mask'] = torch.cat([feature['attention_mask'], torch.zeros(padding_length, dtype=torch.long)])
        if feature['labels'] is not None:
            feature['labels'] = torch.cat([feature['labels'], torch.full((padding_length,), tokenizer.pad_token_id)])
        else:
            feature['labels'] = torch.full((max_length,), tokenizer.pad_token_id)
        return feature

    features = [pad_to_max_length(feature, max_length) for feature in features]
    batch = {
        key: torch.stack([feature[key] for feature in features])
        for key in features[0].keys()
    }

    if images:
        batch['images'] = images

    return batch

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=TORCH_TYPE,
        trust_remote_code=True,
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        low_cpu_mem_usage=True
    ).eval()

input_sample_list = []
input_sample = model.build_conversation_input_ids(tokenizer, query=query, history=[], images=[image], template_version='chat')
input_sample_list.append(input_sample)

input_batch = collate_fn(input_sample_list, tokenizer)
input_batch = recur_move_to(input_batch, device, lambda x: isinstance(x, torch.Tensor))
input_batch = recur_move_to(input_batch, torch.bfloat16, lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x))
print(input_batch.keys())

def prepare_logits_processor(
    temperature: float, top_p: float
) -> LogitsProcessorList:
    processor_list = LogitsProcessorList()
    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
    processor_list.append(TopKLogitsWarper(top_k))  # Fix: suggested adding a top_k constraint
    if temperature >= 1e-5 and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if 1e-8 <= top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    return processor_list

def _get_token_from_logits(
    logits, temperature, top_p
):
    logits_processor = prepare_logits_processor(
        temperature, top_p
    )
    last_token_logits = logits_processor(None, logits[0:1, -1, :])[
        0
    ]
    probs = torch.softmax(last_token_logits, dim=-1)
    indices = torch.multinomial(probs, num_samples=2)
    token = indices[0].int().item()
    return token

def build_position_ids(x, attention_mask=None):
    # Fix: adapted from the official open-source implementation
    if attention_mask is not None:
        tmp = x.clone()
        tmp[~(attention_mask.bool())] = -1
    else:
        tmp = x.clone()
    # image boi eoi token as LANGUAGE_TOKEN_TYPE
    is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
    is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
    is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
    is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
    is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
    tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
    # final position ids
    y = torch.zeros_like(x, dtype=torch.long)
    y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ((tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
    y = y.cumsum(dim=-1)
    return y

def _get_inference_kws(kv, position_id):
    batch_size, seq_length, device = (
        kv[0][0].shape[0],
        kv[0][0].shape[2] + 1,
        kv[0][0].device,
    )
    res = {}
    position_ids = torch.full(
        (batch_size, 1), fill_value=position_id+1, dtype=torch.long, device=device
    )  # Fix: changed how position_id is updated
    attention_mask = torch.ones(
        (batch_size, seq_length), dtype=torch.long, device=device
    )
    res["attention_mask"] = attention_mask
    res["position_ids"] = position_ids
    token_type_ids = torch.full(
        (batch_size, 1), fill_value=0, dtype=torch.long, device=device
    )
    res["token_type_ids"] = token_type_ids
    return res

@torch.inference_mode()
def execute():
    # prefill
    new_tokens = []
    # Fix: the previous code did not pass position_ids during the prefill stage
    input_batch["position_ids"] = build_position_ids(input_batch['token_type_ids'])
    out = model(**input_batch, use_cache=True)
    logits = out.logits
    kv_cache = out.past_key_values

    token = _get_token_from_logits(logits, temperature, top_p)
    new_tokens.append(token)
    max_position_id = input_batch["position_ids"][0,-1].item()

    # decode
    for idx in range(max_new_tokens):
        decode_tokens = [[new_tokens[-1]]]
        inf_kws = _get_inference_kws(kv_cache, max_position_id)  # Fix: pass in max_position_id
        out = model(
            input_ids=torch.as_tensor(decode_tokens, device=device),
            use_cache=True,
            past_key_values=kv_cache,
            **inf_kws,
        )
        logits = out.logits
        kv_cache = out.past_key_values
        max_position_id += 1  # Fix: increment max_position_id by 1

        token = _get_token_from_logits(logits, temperature, top_p)
        new_tokens.append(token)
        print(f"Output: {token}")
        if token == 128001:
            break

    res = tokenizer.decode(
            new_tokens,
            skip_special_tokens=True,
            spaces_between_special_tokens=False,
            clean_up_tokenization_spaces=True)
    print(f"Results: {res}")

if __name__ == '__main__':
    execute()

ChengjieLi28 commented 5 months ago

@iyuge2 After fixing the position_id handling it works (both single-sample and batch inference succeed perfectly). Thank you very much!