google-research / pix2struct

Apache License 2.0

Unable to run inference for the InfographicVQA task #25

Open ShubhamAwasthi1 opened 1 year ago

ShubhamAwasthi1 commented 1 year ago

I am trying to run inference with the model for the InfographicVQA task. The instructions give the following CLI command for a dummy task:

python -m pix2struct.example_inference \
  --gin_search_paths="pix2struct/configs" \
  --gin_file=models/pix2struct.gin \
  --gin_file=runs/inference.gin \
  --gin_file=sizes/base.gin \
  --gin.MIXTURE_OR_TASK_NAME="'dummy_pix2struct'" \
  --gin.TASK_FEATURE_LENGTHS="{'inputs': 2048, 'targets': 128}" \
  --gin.BATCH_SIZE=1 \
  --gin.CHECKPOINT_PATH="'gs://pix2struct-data/textcaps_base/checkpoint_280400'" \
  --image=$HOME/test_image.jpg

I have set the task name and checkpoint and added a text prompt for the VQA task (command below), but these values do not match what the script expects. Could you provide a correct set of input values for running inference on this task?

python -m pix2struct.example_inference \
  --gin_search_paths="pix2struct/configs" \
  --gin_file=models/pix2struct.gin \
  --gin_file=runs/inference.gin \
  --gin_file=sizes/base.gin \
  --gin.MIXTURE_OR_TASK_NAME="InfographicVQA" \
  --gin.TASK_FEATURE_LENGTHS="{'inputs': 2048, 'targets': 128}" \
  --gin.BATCH_SIZE=1 \
  --gin.CHECKPOINT_PATH="gs://pix2struct-data/infographicvqa_large/checkpoint_182000" \
  --image="my_input_image.jpeg" \
  --text="What is written on the image of the calendar ?"
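One thing worth checking in the command above: gin parses each `--gin.NAME=...` value like a Python literal, so string bindings need their own inner quotes, as in the README's `--gin.MIXTURE_OR_TASK_NAME="'dummy_pix2struct'"`. The command above passes `InfographicVQA` and the `gs://` path without them, and it also pairs an `infographicvqa_large` checkpoint with `sizes/base.gin`. The snippet builds the argument list in Python so the quoting is explicit. It is a minimal sketch under stated assumptions, not a confirmed fix: `sizes/large.gin`, reusing `dummy_pix2struct` as the task for a VQA checkpoint, and the `--text` flag (copied from the command above) have not been verified against the repository, and `gin_str` / `build_inference_argv` are hypothetical helper names.

```python
import shlex


def gin_str(value: str) -> str:
    """Return `value` as a gin string literal, e.g. dummy_pix2struct -> 'dummy_pix2struct'."""
    # repr() produces a quoted Python string literal, the form gin expects for strings.
    return repr(value)


def build_inference_argv(checkpoint_path, image_path, question=None,
                         size="large", task="dummy_pix2struct"):
    """Build the argv list for pix2struct.example_inference (hypothetical helper).

    Assumptions (not verified against the repository): `size` names a config under
    configs/sizes/ that matches the checkpoint size, the README's dummy task can stand
    in for single-image VQA inference, and the script accepts a --text flag.
    """
    argv = [
        "python", "-m", "pix2struct.example_inference",
        "--gin_search_paths=pix2struct/configs",
        "--gin_file=models/pix2struct.gin",
        "--gin_file=runs/inference.gin",
        f"--gin_file=sizes/{size}.gin",
        f"--gin.MIXTURE_OR_TASK_NAME={gin_str(task)}",
        "--gin.TASK_FEATURE_LENGTHS={'inputs': 2048, 'targets': 128}",
        "--gin.BATCH_SIZE=1",
        f"--gin.CHECKPOINT_PATH={gin_str(checkpoint_path)}",
        f"--image={image_path}",
    ]
    if question is not None:
        argv.append(f"--text={question}")
    return argv


argv = build_inference_argv(
    "gs://pix2struct-data/infographicvqa_large/checkpoint_182000",
    "my_input_image.jpeg",
    question="What is written on the image of the calendar?",
)
# Print a shell-safe version of the command. Passing `argv` directly to
# subprocess.run(argv, check=True) from the repository root skips shell quoting entirely.
print(shlex.join(argv))
```

Handing the list to `subprocess.run` avoids the nested `"'...'"` quoting that is easy to get wrong when typing the command into a shell.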

NielsRogge commented 1 year ago

Hi,

Inference might be a bit easier now with the Hugging Face integration. All checkpoints are on the Hub, so if you want to try the InfographicVQA large checkpoint (google/pix2struct-infographics-vqa-large), you can do so as follows:

from PIL import Image
import requests
from transformers import AutoProcessor, Pix2StructForConditionalGeneration

# Load the processor (image/text preprocessing) and the InfographicVQA fine-tuned model
processor = AutoProcessor.from_pretrained("google/pix2struct-infographics-vqa-large")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-infographics-vqa-large")

# Example image; replace it with the infographic you want to ask about
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

question = "What is written on the image of the calendar?"

# For VQA checkpoints the processor renders the question onto the image as a header
inputs = processor(images=image, text=question, return_tensors="pt")

# autoregressive generation
predicted_ids = model.generate(**inputs)
predicted_text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
print(predicted_text)