google-research / pix2struct

Apache License 2.0

How to get a confidence score or probability score from the Pix2Struct VQA model #36

Open sricharanamarnath opened 1 year ago

sricharanamarnath commented 1 year ago

How do we get a confidence score for the predictions of the Pix2Struct model? In the code snippet below, how do we get the prediction score or probability score for pred[0]?

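# Assumed imports: the original post does not show where psg, psp and convert_from_path come from
import torch
from pdf2image import convert_from_path
from transformers import Pix2StructForConditionalGeneration as psg
from transformers import Pix2StructProcessor as psp

FILENAME = "XXX.pdf"
PAGE_NO = 1
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
questions = ["question?"]  # placeholder; the question list is not shown in the original post

model = psg.from_pretrained("google/pix2struct-docvqa-base").to(DEVICE)
processor = psp.from_pretrained("google/pix2struct-docvqa-large")

def convert_pdf_to_image(filename, page_no):
    # convert_from_path returns one PIL image per page; page_no is 1-based
    return convert_from_path(filename)[page_no - 1]

image = convert_pdf_to_image(FILENAME, PAGE_NO)

# Repeat the page image once per question so images and questions line up
inputs = processor(images=[image] * len(questions), text=questions, return_tensors="pt").to(DEVICE)

pred = model.generate(**inputs)
print(processor.decode(pred[0], skip_special_tokens=True))
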
Jaiczay commented 1 year ago

Did you find a solution by now?

sricharanamarnath commented 1 year ago

Did you find a solution by now?

Not yet, no luck. I tried implementing it, but it doesn't work.

Jaiczay commented 1 year ago

I found out how to get a score per generated token (transition_proba):

from PIL import Image
import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

image = Image.open("image.jpg")
question = 'question?'

model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-infographics-vqa-large")
# Workaround: set vocab_size on the top-level config, presumably because
# compute_transition_scores reads it from there
model.config.vocab_size = 50244
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-infographics-vqa-large")

inputs = processor(images=image, text=question, return_tensors="pt")
# Return a dict that also carries the per-step scores, not just the token ids
predictions = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)

# Log-probability of each generated token (normalize_logits applies a log-softmax first)
transition_scores = model.compute_transition_scores(predictions.sequences, predictions.scores, normalize_logits=True)
# Exponentiate to get probabilities; [0] selects the first (only) sequence
transition_proba = torch.exp(transition_scores)[0]

answer = processor.decode(predictions.sequences[0], skip_special_tokens=True)

But from what I have tested so far, the score doesn't help to decide whether the answer is correct or wrong (or whether it can be found in the context). Maybe it helps you.
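
If a single number per answer is wanted rather than one probability per token, the per-token values could be combined into sequence-level summaries. The following is only a rough sketch, not something from the pix2struct or transformers code: the function name `sequence_confidence` and the example probabilities are made up, and the input is assumed to be a plain list such as `transition_proba.tolist()` for a single, unpadded sequence. Given the observation above, these summaries may well not separate correct from wrong answers either.

```python
import math


def sequence_confidence(token_probs):
    """Summarise per-token probabilities into sequence-level scores.

    token_probs: list of floats in (0, 1], one per generated token
    (hypothetical input, e.g. transition_proba.tolist()).
    """
    if not token_probs:
        raise ValueError("token_probs must contain at least one probability")
    # Sum log-probabilities instead of multiplying raw values to avoid underflow;
    # the small floor guards against log(0).
    log_probs = [math.log(max(p, 1e-12)) for p in token_probs]
    total = sum(log_probs)
    return {
        # Probability of the whole sequence (shrinks as answers get longer)
        "joint": math.exp(total),
        # Length-normalised score, comparable across answers of different length
        "geometric_mean": math.exp(total / len(log_probs)),
        # The least certain token, a simple weakest-link signal
        "min": min(token_probs),
    }


# Made-up probabilities standing in for real model output
example = sequence_confidence([0.9, 0.8, 0.5])
print(example)
```

The geometric mean is usually easier to threshold than the joint probability, because the joint probability drops with every extra token regardless of how certain each step was.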