Yuliang-Liu / Monkey

【CVPR 2024 Highlight】Monkey (LMM): Image Resolution and Text Label Are Important Things for Large Multi-modal Models
MIT License

Does Textmonkey have inference code? Why doesn't the web demo answer when running? #94

Closed zhangxilin1 closed 5 months ago

zhangxilin1 commented 5 months ago

Running Textmonkey with the monkey-chat inference code raises `RuntimeError: The expanded size of the tensor (1280) must match the existing size (768) at non-singleton dimension 0. Target sizes: [1280, 4096]. Tensor sizes: [768, 4096]`. The Textmonkey web demo shows no errors in the terminal, but after I upload an image and ask a question it never answers and just keeps loading.

echo840 commented 5 months ago

Hello, regarding "the Textmonkey web demo shows no errors in the terminal, but never answers after an image and question are uploaded and just keeps loading" — this may be caused by a network problem on your end.

You can also use the following inference code:

from monkey_model.modeling_textmonkey import TextMonkeyLMHeadModel
from monkey_model.tokenization_qwen import QWenTokenizer
from monkey_model.configuration_monkey import MonkeyConfig

if __name__ == "__main__":
    checkpoint_path = ""  # path to the TextMonkey checkpoint
    input_image = ""      # path to the input image
    input_str = "OCR with grounding:"
    device_map = "cuda"

    # Create the model and tokenizer
    config = MonkeyConfig.from_pretrained(
        checkpoint_path,
        trust_remote_code=True,
    )
    model = TextMonkeyLMHeadModel.from_pretrained(
        checkpoint_path,
        config=config,
        device_map=device_map,
        trust_remote_code=True,
    ).eval()
    tokenizer = QWenTokenizer.from_pretrained(checkpoint_path,
                                              trust_remote_code=True)
    tokenizer.padding_side = 'left'
    tokenizer.pad_token_id = tokenizer.eod_id
    # Match the image-token span to the visual encoder's query count;
    # reusing Monkey-chat's code, which assumes a different span, can
    # trigger the tensor-size RuntimeError reported above.
    tokenizer.IMG_TOKEN_SPAN = config.visual["n_queries"]

    # Wrap the image path in <img> tags so the tokenizer inserts image tokens
    input_str = f"<img>{input_image}</img> {input_str}"
    inputs = tokenizer(input_str, return_tensors='pt', padding='longest')
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    # Greedy decoding (no sampling, single beam)
    pred = model.generate(
        input_ids=input_ids.cuda(),
        attention_mask=attention_mask.cuda(),
        do_sample=False,
        num_beams=1,
        max_new_tokens=2048,
        min_new_tokens=1,
        length_penalty=1,
        num_return_sequences=1,
        output_hidden_states=True,
        use_cache=True,
        pad_token_id=tokenizer.eod_id,
        eos_token_id=tokenizer.eod_id,
    )
    # Decode only the newly generated tokens; keep special tokens so the
    # grounding markup survives
    response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=False).strip()
    print(response)
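Because the decode call above uses `skip_special_tokens=False` (so grounding markers like `<ref>`/`<box>` are preserved), the raw response may still end with the tokenizer's end-of-document token. A minimal post-processing sketch, assuming Qwen's `<|endoftext|>` is the EOD string:

```python
def clean_response(response: str) -> str:
    """Strip the end-of-document marker left by skip_special_tokens=False.

    "<|endoftext|>" as the Qwen EOD string is an assumption here; grounding
    tags such as <ref>...</ref> and <box>...</box> are left intact.
    """
    return response.replace("<|endoftext|>", "").strip()
```

For example, `clean_response("<ref>text</ref><|endoftext|>")` returns `"<ref>text</ref>"`.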