@ChengjieLi28 Hi, could you share your modified script? We'd like to try reproducing your issue locally, thanks!
Sure, I'll post a minimal reproduction shortly.
@iyuge2 Hi, the code is as follows:
import os
import time

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, TemperatureLogitsWarper, \
    TopPLogitsWarper

MODEL_PATH = "<model_path>"
PIC_PATH = "<picture_path>"
TORCH_TYPE = torch.bfloat16
device = 'cuda:0'
query = "Describe the picture."
image = Image.open(PIC_PATH).convert('RGB')
temperature = 0.6
top_p = 0.9
max_new_tokens = 512
def recur_move_to(item, tgt, criterion_func):
    if criterion_func(item):
        device_copy = item.to(tgt)
        return device_copy
    elif isinstance(item, list):
        return [recur_move_to(v, tgt, criterion_func) for v in item]
    elif isinstance(item, tuple):
        return tuple([recur_move_to(v, tgt, criterion_func) for v in item])
    elif isinstance(item, dict):
        return {k: recur_move_to(v, tgt, criterion_func) for k, v in item.items()}
    else:
        return item
def collate_fn(features, tokenizer) -> dict:
    images = [feature.pop('images', None) for feature in features if 'images' in feature]
    tokenizer.pad_token = tokenizer.eos_token
    max_length = max(len(feature['input_ids']) for feature in features)

    def pad_to_max_length(feature, max_length):
        padding_length = max_length - len(feature['input_ids'])
        print(f"===padding_length: {padding_length}")
        feature['input_ids'] = torch.cat([feature['input_ids'], torch.full((padding_length,), tokenizer.pad_token_id)])
        feature['token_type_ids'] = torch.cat([feature['token_type_ids'], torch.zeros(padding_length, dtype=torch.long)])
        feature['attention_mask'] = torch.cat([feature['attention_mask'], torch.zeros(padding_length, dtype=torch.long)])
        if feature['labels'] is not None:
            feature['labels'] = torch.cat([feature['labels'], torch.full((padding_length,), tokenizer.pad_token_id)])
        else:
            feature['labels'] = torch.full((max_length,), tokenizer.pad_token_id)
        return feature

    features = [pad_to_max_length(feature, max_length) for feature in features]
    batch = {
        key: torch.stack([feature[key] for feature in features])
        for key in features[0].keys()
    }
    if images:
        batch['images'] = images
    return batch
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_TYPE,
    trust_remote_code=True,
    device_map=device,
    # load_in_4bit=True,
    # low_cpu_mem_usage=True
).eval()

input_sample_list = []
input_sample = model.build_conversation_input_ids(tokenizer, query=query, history=[], images=[image], template_version='chat')
input_sample_list.append(input_sample)
input_batch = collate_fn(input_sample_list, tokenizer)
input_batch = recur_move_to(input_batch, device, lambda x: isinstance(x, torch.Tensor))
input_batch = recur_move_to(input_batch, torch.bfloat16, lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x))
print(input_batch.keys())
def prepare_logits_processor(
    temperature: float, top_p: float
) -> LogitsProcessorList:
    processor_list = LogitsProcessorList()
    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
    if temperature >= 1e-5 and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if 1e-8 <= top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    return processor_list


def _get_token_from_logits(
    logits, temperature, top_p
):
    logits_processor = prepare_logits_processor(
        temperature, top_p
    )
    last_token_logits = logits_processor(None, logits[0:1, -1, :])[0]
    probs = torch.softmax(last_token_logits, dim=-1)
    indices = torch.multinomial(probs, num_samples=2)
    token = indices[0].int().item()
    return token
def _get_inference_kws(kv):
    batch_size, seq_length, device = (
        kv[0][0].shape[0],
        kv[0][0].shape[2] + 1,
        kv[0][0].device,
    )
    res = {}
    position_ids = torch.full(
        (batch_size, 1), fill_value=seq_length - 1, dtype=torch.long, device=device
    )
    attention_mask = torch.ones(
        (batch_size, seq_length), dtype=torch.long, device=device
    )
    res["attention_mask"] = attention_mask
    res["position_ids"] = position_ids
    token_type_ids = torch.full(
        (batch_size, 1), fill_value=1, dtype=torch.long, device=device
    )
    res["token_type_ids"] = token_type_ids
    return res
@torch.inference_mode()
def execute():
    # prefill
    new_tokens = []
    out = model(**input_batch, use_cache=True)
    logits = out.logits
    kv_cache = out.past_key_values
    token = _get_token_from_logits(logits, temperature, top_p)
    new_tokens.append(token)
    # decode
    for idx in range(max_new_tokens):
        decode_tokens = [[new_tokens[-1]]]
        inf_kws = _get_inference_kws(kv_cache)
        out = model(
            input_ids=torch.as_tensor(decode_tokens, device=device),
            use_cache=True,
            past_key_values=kv_cache,
            **inf_kws,
        )
        logits = out.logits
        kv_cache = out.past_key_values
        token = _get_token_from_logits(logits, temperature, top_p)
        new_tokens.append(token)
        print(f"Output: {token}")
        if token == 128001:
            break
    res = tokenizer.decode(
        new_tokens,
        skip_special_tokens=True,
        spaces_between_special_tokens=False,
        clean_up_tokenization_spaces=True)
    print(f"Results: {res}")


if __name__ == '__main__':
    execute()
Notes:
recur_move_to and collate_fn are copied directly from https://github.com/THUDM/CogVLM2/blob/main/basic_demo/cli_demo_batch_inference.py.
My results:
With the query "Describe the picture." (English):
With the query "描述此图" (Chinese):
Additional note: the model is THUDM/cogvlm2-llama3-chinese-chat-19B-int4, downloaded from ModelScope.
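(For completeness, the download step looks roughly like the sketch below; the ModelScope repo id is an assumption, please check the model page for the exact id.)

from modelscope import snapshot_download

# Assumed ModelScope repo id for the int4 Chinese chat model; verify it on the model page.
MODEL_PATH = snapshot_download('ZhipuAI/cogvlm2-llama3-chinese-chat-19B-int4')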
@ChengjieLi28 Got it, we'll take a look when we get a chance and follow up with you once we have results.
@ChengjieLi28 Hi, I looked into this. The main problem is in how the position ids are handled: in your decode step you derive the next position id from the KV-cache length,

position_ids = torch.full(
    (batch_size, 1), fill_value=seq_length - 1, dtype=torch.long, device=device
)

but in CogVLM2 the vision tokens largely share position ids, so seq_length - 1 overshoots the position the model actually reached during prefill (see the small check after the script below).
Below is my lightly modified version of your code; the places that differ from yours are marked with Fix: ***. My preliminary tests look fine, but since I haven't run detailed metric evaluations, I still recommend stepping through the official generate implementation with a debugger and refactoring the overall logic based on it.
import os
import time

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, TemperatureLogitsWarper, \
    TopPLogitsWarper, TopKLogitsWarper, BitsAndBytesConfig

MODEL_PATH = ""  # ckpt
PIC_PATH = ""  # image_path
TORCH_TYPE = torch.bfloat16
LANGUAGE_TOKEN_TYPE = 0
VISION_TOKEN_TYPE = 1
device = 'cuda'
query = ""
# query = "Describe this image."
image = Image.open(PIC_PATH).convert('RGB')
temperature = 0.6
top_p = 0.9
top_k = 1
max_new_tokens = 512
def recur_move_to(item, tgt, criterion_func):
    if criterion_func(item):
        device_copy = item.to(tgt)
        return device_copy
    elif isinstance(item, list):
        return [recur_move_to(v, tgt, criterion_func) for v in item]
    elif isinstance(item, tuple):
        return tuple([recur_move_to(v, tgt, criterion_func) for v in item])
    elif isinstance(item, dict):
        return {k: recur_move_to(v, tgt, criterion_func) for k, v in item.items()}
    else:
        return item
def collate_fn(features, tokenizer) -> dict:
    images = [feature.pop('images', None) for feature in features if 'images' in feature]
    tokenizer.pad_token = tokenizer.eos_token
    max_length = max(len(feature['input_ids']) for feature in features)

    def pad_to_max_length(feature, max_length):
        padding_length = max_length - len(feature['input_ids'])
        print(f"===padding_length: {padding_length}")
        feature['input_ids'] = torch.cat([feature['input_ids'], torch.full((padding_length,), tokenizer.pad_token_id)])
        feature['token_type_ids'] = torch.cat([feature['token_type_ids'], torch.zeros(padding_length, dtype=torch.long)])
        feature['attention_mask'] = torch.cat([feature['attention_mask'], torch.zeros(padding_length, dtype=torch.long)])
        if feature['labels'] is not None:
            feature['labels'] = torch.cat([feature['labels'], torch.full((padding_length,), tokenizer.pad_token_id)])
        else:
            feature['labels'] = torch.full((max_length,), tokenizer.pad_token_id)
        return feature

    features = [pad_to_max_length(feature, max_length) for feature in features]
    batch = {
        key: torch.stack([feature[key] for feature in features])
        for key in features[0].keys()
    }
    if images:
        batch['images'] = images
    return batch
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_TYPE,
    trust_remote_code=True,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    low_cpu_mem_usage=True
).eval()

input_sample_list = []
input_sample = model.build_conversation_input_ids(tokenizer, query=query, history=[], images=[image], template_version='chat')
input_sample_list.append(input_sample)
input_batch = collate_fn(input_sample_list, tokenizer)
input_batch = recur_move_to(input_batch, device, lambda x: isinstance(x, torch.Tensor))
input_batch = recur_move_to(input_batch, torch.bfloat16, lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x))
print(input_batch.keys())
def prepare_logits_processor(
    temperature: float, top_p: float
) -> LogitsProcessorList:
    processor_list = LogitsProcessorList()
    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
    processor_list.append(TopKLogitsWarper(top_k))  # Fix: recommend adding a top_k constraint
    if temperature >= 1e-5 and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if 1e-8 <= top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    return processor_list


def _get_token_from_logits(
    logits, temperature, top_p
):
    logits_processor = prepare_logits_processor(
        temperature, top_p
    )
    last_token_logits = logits_processor(None, logits[0:1, -1, :])[0]
    probs = torch.softmax(last_token_logits, dim=-1)
    indices = torch.multinomial(probs, num_samples=2)
    token = indices[0].int().item()
    return token
def build_position_ids(x, attention_mask=None):
    # Fix: taken from the official open-source modeling code
    if attention_mask is not None:
        tmp = x.clone()
        tmp[~(attention_mask.bool())] = -1
    else:
        tmp = x.clone()
    # image boi/eoi tokens are treated as LANGUAGE_TOKEN_TYPE
    is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
    is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
    is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
    is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
    is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
    tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
    # final position ids
    y = torch.zeros_like(x, dtype=torch.long)
    y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ((tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
    y = y.cumsum(dim=-1)
    return y
def _get_inference_kws(kv, position_id):
    batch_size, seq_length, device = (
        kv[0][0].shape[0],
        kv[0][0].shape[2] + 1,
        kv[0][0].device,
    )
    res = {}
    position_ids = torch.full(
        (batch_size, 1), fill_value=position_id + 1, dtype=torch.long, device=device
    )  # Fix: change how the position id is advanced
    attention_mask = torch.ones(
        (batch_size, seq_length), dtype=torch.long, device=device
    )
    res["attention_mask"] = attention_mask
    res["position_ids"] = position_ids
    token_type_ids = torch.full(
        (batch_size, 1), fill_value=0, dtype=torch.long, device=device
    )  # note: newly generated tokens are language tokens (LANGUAGE_TOKEN_TYPE)
    res["token_type_ids"] = token_type_ids
    return res
@torch.inference_mode()
def execute():
    # prefill
    new_tokens = []
    # Fix: the previous code did not pass position_ids during prefill
    input_batch["position_ids"] = build_position_ids(input_batch['token_type_ids'])
    out = model(**input_batch, use_cache=True)
    logits = out.logits
    kv_cache = out.past_key_values
    token = _get_token_from_logits(logits, temperature, top_p)
    new_tokens.append(token)
    max_position_id = input_batch["position_ids"][0, -1].item()
    # decode
    for idx in range(max_new_tokens):
        decode_tokens = [[new_tokens[-1]]]
        inf_kws = _get_inference_kws(kv_cache, max_position_id)  # Fix: pass in max_position_id
        out = model(
            input_ids=torch.as_tensor(decode_tokens, device=device),
            use_cache=True,
            past_key_values=kv_cache,
            **inf_kws,
        )
        logits = out.logits
        kv_cache = out.past_key_values
        max_position_id += 1  # Fix: increment max_position_id by 1
        token = _get_token_from_logits(logits, temperature, top_p)
        new_tokens.append(token)
        print(f"Output: {token}")
        if token == 128001:
            break
    res = tokenizer.decode(
        new_tokens,
        skip_special_tokens=True,
        spaces_between_special_tokens=False,
        clean_up_tokenization_spaces=True)
    print(f"Results: {res}")


if __name__ == '__main__':
    execute()
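A quick way to see why the original fill_value=seq_length - 1 overshoots is to run the build_position_ids helper above on a toy token_type_ids (a minimal sketch; the 7-token example is made up):

import torch

# Toy prompt: 1 language token, 4 vision tokens, 2 language tokens.
toy_token_type_ids = torch.tensor([[0, 1, 1, 1, 1, 0, 0]])
print(build_position_ids(toy_token_type_ids))
# tensor([[0, 1, 2, 2, 3, 4, 5]]) -> the vision block only advances the position at its boundaries
print(build_position_ids(toy_token_type_ids)[0, -1].item() + 1)  # 6: correct position id for the first decode step
print(toy_token_type_ids.shape[1])  # 7: what the original seq_length - 1 (the KV-cache length) would pass instead

With a real image the vision block is much longer but still shares positions, so the KV-cache length is far larger than the position reached during prefill, which plausibly explains the premature 128001 in the Chinese runs.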
@iyuge2 After changing how the position ids are handled it works (both single and batch inference succeed perfectly), thanks a lot!
System Info / 系統信息
Python 3.10, latest transformers
Who can help? / 谁可以帮助到您?
@zRzRzRzRzRzRzR
Information / 问题信息
Reproduction / 复现过程
Premise: I cannot use the generate method for inference, because I need control over the details of the decoding process, e.g. access to the kv_cache (past_key_values). So I can only call the model directly via __call__.
Symptom: with exactly the same code and the Chinese model THUDM/cogvlm2-llama3-chinese-chat-19B-int4:
With an English query, everything works fine:
With a Chinese query:
The Chinese answer seems to stop too early. I printed every token produced in the decode loop and found that 128001 (the stop_token_id) shows up prematurely, which causes this behavior.
Do you have any clues or suggestions about where the problem might be?
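(A small sanity check, using the tokenizer from the script above, is to confirm which special token id 128001 maps to; the exact mapping depends on the checkpoint.)

print(tokenizer.convert_ids_to_tokens([128001]))
# For Llama-3-based tokenizers this is typically ['<|end_of_text|>'], but verify against your checkpoint.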
Expected behavior / 期待表现
Chinese queries should work correctly.