haotian-liu / LLaVA

[NeurIPS'23 Oral] Visual Instruction Tuning (LLaVA) built towards GPT-4V level capabilities and beyond.
https://llava.hliu.cc
Apache License 2.0

Llava becomes repetitive on OCR tasks #1432

Open prhbrt opened 3 months ago

prhbrt commented 3 months ago

Describe the issue

OCR might not be the target task of LLaVA, but data is data, and I still wanted to make a quick report on this.

I tried OCR on these two images:

[Image: patent-smaller-cut]

[Image: patent-smaller-cut-fraktur]

Result:

[Image: model output for the first page, degenerating into repetition]

[Image: model output for the second page, degenerating into repetition]

Code (copied from the cli file):

import torch
from PIL import Image
from transformers import TextStreamer

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path

disable_torch_init()

model_base = None
# model_path = "liuhaotian/llava-v1.5-7b"
# model_path = "liuhaotian/llava-v1.5-13b-lora"
model_path = "liuhaotian/llava-v1.6-vicuna-13b"
model_name = get_model_name_from_path(model_path)
load_8bit = False
load_4bit = False
device = 'cuda:0'

if "llama-2" in model_name.lower():
    conv_mode = "llava_llama_2"
elif "mistral" in model_name.lower():
    conv_mode = "mistral_instruct"
elif "v1.6-34b" in model_name.lower():
    conv_mode = "chatml_direct"
elif "v1" in model_name.lower():
    conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
    conv_mode = "mpt"
else:
    conv_mode = "llava_v0"

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base, model_name, load_8bit, load_4bit, device=device)

image_file = "patent-smaller-cut-fraktur.png"
image = Image.open(image_file).convert('RGB')
image_size = image.size

# Preprocess the image (process_images may return a list or a single tensor)
image_tensor = process_images([image], image_processor, model.config)
if type(image_tensor) is list:
    image_tensor = [t.to(model.device, dtype=torch.float16) for t in image_tensor]
else:
    image_tensor = image_tensor.to(model.device, dtype=torch.float16)

# Build the image-token prefix as the CLI does, based on the model config
if model.config.mm_use_im_start_end:
    prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n'
else:
    prefix = DEFAULT_IMAGE_TOKEN + '\n'

prompt = "Transcribe this page the best you can. It has 3 columns."
max_new_tokens = 4096
temperature = 0

conv = conv_templates[conv_mode].copy()
if "mpt" in model_name.lower():
    roles = ('user', 'assistant')  # display names only; message roles come from conv.roles
else:
    roles = conv.roles

conv.append_message(conv.roles[0], prefix + question)
conv.append_message(conv.roles[1], None)

prompt = conv.get_prompt()

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]  # carried over from the CLI; not wired into a stopping criteria here
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=[image_size],
        do_sample=True if temperature > 0 else False,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        # streamer=streamer,
        use_cache=True)

outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
conv.messages[-1][-1] = outputs
print(outputs)
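
For reference, with greedy decoding (temperature 0), long transcriptions are prone to exactly this kind of degenerate loop. Below is a minimal, untested variant of the generate call above that adds transformers' stock anti-repetition knobs; repetition_penalty and no_repeat_ngram_size are standard HuggingFace generation kwargs, and the values are guesses, not LLaVA defaults:

# Sketch: same call as above, plus transformers' standard anti-repetition
# options. The values below are untested guesses, not LLaVA defaults.
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=[image_size],
        do_sample=False,               # still greedy
        repetition_penalty=1.2,        # penalize tokens already generated
        no_repeat_ngram_size=6,        # block exact 6-gram repeats
        max_new_tokens=max_new_tokens,
        use_cache=True)

Note that hard n-gram blocking can also mangle genuinely repetitive text (running headers, table rows), so this is more a diagnostic knob than a fix.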

Installed LLaVA from GitHub:

commit 3e337ad269da3245643a2724a1d694b5839c37f9 (HEAD -> main, origin/main, origin/HEAD)
Author: ZhaoyangLi <43194342+ZhaoyangLi-nju@users.noreply.github.com>
Date:   Fri Apr 19 03:11:11 2024 +0800

    Update Evaluation.md (#1358)

    update the new path with VizWiz VQA Challenge 2024
Yonggie commented 2 months ago

Happens to me a lot. It somehow gets stuck in an endless repetition loop.
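
A quick way to quantify the loop (a hypothetical diagnostic, not part of LLaVA): count the most frequent word n-gram in the decoded output; a runaway count means the model is cycling.

from collections import Counter

def repetition_report(text, n=8):
    # Return the most frequent word n-gram in `text` and how often it occurs.
    words = text.split()
    ngrams = [' '.join(words[i:i + n]) for i in range(len(words) - n + 1)]
    if not ngrams:
        return None, 0
    return Counter(ngrams).most_common(1)[0]

ngram, count = repetition_report(outputs)
print(f"most frequent 8-gram occurs {count}x: {ngram!r}")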