[Open] prhbrt opened this issue 3 months ago
OCR may not be LLaVA's target task, but data is data, so I still wanted to file a quick report on this.
I tried OCR on these two images:
Result:
Code (adapted from LLaVA's CLI script):
import torch
import requests
import argparse
from PIL import Image
from io import BytesIO
from transformers import TextStreamer

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from llava.eval.run_llava import eval_model

# Single-image OCR attempt with LLaVA, adapted from the LLaVA CLI: load a
# checkpoint, build a one-turn conversation (image + transcription prompt),
# decode the reply into `outputs`, and print it.


def _select_conv_mode(model_name):
    """Return the conversation-template key for *model_name*.

    Name-based heuristic; the order of checks matters because, e.g.,
    "llava-v1.6-34b" also contains "v1".
    """
    name = model_name.lower()
    if "llama-2" in name:
        return "llava_llama_2"
    if "mistral" in name:
        return "mistral_instruct"
    if "v1.6-34b" in name:
        return "chatml_direct"
    if "v1" in name:
        return "llava_v1"
    if "mpt" in name:
        return "mpt"
    return "llava_v0"


def _image_prefix(mm_use_im_start_end):
    """Return the text placed before the user prompt to stand in for the image.

    The <im_start>/<im_end> wrapper is only used when the model config sets
    `mm_use_im_start_end`; otherwise the bare image token is sent. Always
    sending the wrapper gives the model a prompt format it was not configured for.
    """
    if mm_use_im_start_end:
        return DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n'
    return DEFAULT_IMAGE_TOKEN + '\n'


disable_torch_init()

# --- model selection -------------------------------------------------------
model_base = None
# Alternatives tried:
# model_path = "liuhaotian/llava-v1.5-7b"
# model_path = "liuhaotian/llava-v1.5-13b-lora"
model_path = "liuhaotian/llava-v1.6-vicuna-13b"
model_name = get_model_name_from_path(model_path)
load_8bit = False
load_4bit = False
device = 'cuda:0'

conv_mode = _select_conv_mode(model_name)

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base, model_name, load_8bit, load_4bit, device=device
)

# --- image preprocessing ---------------------------------------------------
image_file = "patent-smaller-cut-fraktur.png"
image = Image.open(image_file).convert('RGB')
image_size = image.size
image_tensor = process_images([image], image_processor, model.config)
# process_images may return either a list of tensors or a single tensor.
if isinstance(image_tensor, list):
    image_tensor = [t.to(model.device, dtype=torch.float16) for t in image_tensor]
else:
    image_tensor = image_tensor.to(model.device, dtype=torch.float16)

# --- prompt construction ---------------------------------------------------
prefix = _image_prefix(getattr(model.config, "mm_use_im_start_end", False))
prompt = "Transcribe this page the best you can. It has 3 columns."
# NOTE(review): the image presumably expands into many tokens inside the model,
# so prompt + image + 4096 new tokens may exceed `context_len`; decoding past
# the context window is a plausible cause of runaway repetition. Confirm and
# lower this if so.
max_new_tokens = 4096
temperature = 0

conv = conv_templates[conv_mode].copy()
conv.append_message(conv.roles[0], prefix + prompt)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = (
    tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    .unsqueeze(0)
    .to(model.device)
)

stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
# Pass `streamer=streamer` to generate() to print tokens as they are produced.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# --- generation ------------------------------------------------------------
do_sample = temperature > 0
generate_kwargs = dict(
    images=image_tensor,
    image_sizes=[image_size],
    do_sample=do_sample,
    max_new_tokens=max_new_tokens,
    use_cache=True,
)
if do_sample:
    # Temperature only affects sampling; passing it with greedy decoding
    # just triggers a generation-config warning.
    generate_kwargs["temperature"] = temperature

with torch.inference_mode():
    output_ids = model.generate(input_ids, **generate_kwargs)

# Drop special tokens (e.g. </s>) so they do not leak into the transcription.
outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
# Also trim the conversation stop string in case it is not a special token
# for this tokenizer.
if stop_str and outputs.endswith(stop_str):
    outputs = outputs[: -len(stop_str)].strip()
conv.messages[-1][-1] = outputs
print(outputs)
Installed LLaVA from GitHub at this commit:
commit 3e337ad269da3245643a2724a1d694b5839c37f9 (HEAD -> main, origin/main, origin/HEAD) Author: ZhaoyangLi <43194342+ZhaoyangLi-nju@users.noreply.github.com> Date: Fri Apr 19 03:11:11 2024 +0800 Update Evaluation.md (#1358) update the new path with VizWiz VQA Challenge 2024
This happens to me a lot, too: the model somehow gets stuck in an endless repetition loop.
Describe the issue
OCR may not be LLaVA's target task, but data is data, so I still wanted to file a quick report on this.
I tried OCR on these two images:
Result:
Code (adapted from LLaVA's CLI script):
Installed LLaVA from GitHub at this commit: