Ucas-HaoranWei / Vary

[ECCV2024] Official code implementation of Vary: Scaling Up the Vision Vocabulary of Large Vision Language Models.

Performance on downstream VQA tasks #76

Closed Veason-silverbullet closed 4 months ago

Veason-silverbullet commented 4 months ago

Dear Authors,

Many thanks for your contribution. I evaluated the Vary-base checkpoint on DocVQA and ChartQA and obtained the results below:

| DocVQA Val (ANLS) | ChartQA-human | ChartQA-augmented |
|---|---|---|
| 65.10 | 27.60 | 69.12 |

These results are lower than reported. May I ask whether the released Vary-base checkpoint is the Stage-1 or the Stage-2 model? If it is the Stage-1 checkpoint, that would explain the gap (and if so, is the Stage-2 checkpoint available?).

Thanks for your attention.

Ucas-HaoranWei commented 4 months ago

Your prompt template may not be aligned with ours. I will provide the eval code when I have time this week.

Veason-silverbullet commented 4 months ago

So may I ask whether the released checkpoint is from the pretraining stage or the SFT stage?

Ucas-HaoranWei commented 4 months ago

The SFT model.

Veason-silverbullet commented 4 months ago

Since there has been no update, I am posting my inference code (DocVQA ANLS = 65.10) for reference. The code below looks correct to me.

```python
import os
from transformers import AutoTokenizer, CLIPImageProcessor, TextStreamer
import torch
from PIL import Image
from vary.utils.conversation import conv_templates, SeparatorStyle
from vary.utils.utils import disable_torch_init, KeywordsStoppingCriteria
from vary.model import *
from vary.model.plug.transforms import test_transform
from eval_utils.anls import evaluate_anls
import json
# Image placeholder tokens expected by the Vary prompt (as defined in vary/utils/constants.py)
DEFAULT_IMAGE_TOKEN = '<image>'
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'
dataset_path = '../datasets/DocVQA'
if not os.path.exists('eval-outputs'):
    os.mkdir('eval-outputs')

if __name__ == '__main__':
    # 1. Load model and configuration
    disable_torch_init()
    tokenizer = AutoTokenizer.from_pretrained('../vary-base', trust_remote_code=True)
    model = varyQwenForCausalLM.from_pretrained('../vary-base', low_cpu_mem_usage=True, device_map='cpu', trust_remote_code=True)
    model.to(device='cuda',  dtype=torch.float16)
    model.eval()
    image_processor = CLIPImageProcessor.from_pretrained('../clip-vit-large-patch14', torch_dtype=torch.float16)
    image_processor_high = test_transform
    use_im_start_end = True
    image_token_len = 256

    # 2. Evaluation by qa pairs
    with open(os.path.join(dataset_path, 'val_v1.0_withQT.json'), 'r', encoding='utf-8') as f:
        data = json.load(f)['data']
    results = []
    for item in data:
        question = item['question'].strip()
        if use_im_start_end:
            question = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + question
        else:
            question = DEFAULT_IMAGE_TOKEN + '\n' + question
        conv = conv_templates['mpt'].copy()
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        inputs = tokenizer([prompt])
        image_file = os.path.join(dataset_path, item['image'])
        image = Image.open(image_file).convert('RGB')
        image_1 = image.copy()
        image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        image_tensor_1 = image_processor_high(image_1)
        input_ids = torch.as_tensor(inputs.input_ids).cuda()
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        with torch.no_grad():
            with torch.autocast('cuda', dtype=torch.float16):
                output_ids = model.generate(
                    input_ids,
                    images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
                    do_sample=False,
                    num_beams=1,
                    streamer=streamer,
                    max_new_tokens=256,
                    stopping_criteria=[stopping_criteria]
                )
                answer = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
                if answer.endswith(stop_str):
                    answer = answer[:-len(stop_str)]
                results.append({'questionId': str(item['questionId']), 'answer': answer})
    with open('eval-outputs/DocVQA.json', 'w', encoding='utf-8') as f:
        json.dump(results, f)

    # 3. Compute ANLS scores
    anls = evaluate_anls(os.path.join(dataset_path, 'val_v1.0_withQT.json'), 'eval-outputs/DocVQA.json', show_answer_types=False, anls_threshold=0.5)
    print('ANLS =', anls)
```
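
Note: `evaluate_anls` is my own helper, not part of the Vary repo; it follows the standard DocVQA ANLS definition (best normalized Levenshtein similarity per question, zeroed below a 0.5 threshold, averaged over all questions). A minimal sketch of that metric, with illustrative names, is roughly:

```python
# Minimal ANLS sketch (standard DocVQA definition); names are illustrative only.
def levenshtein(a: str, b: str) -> int:
    # Classic dynamic-programming edit distance.
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = cur
    return prev[-1]

def anls_score(prediction: str, ground_truths: list, threshold: float = 0.5) -> float:
    # Normalized Levenshtein similarity against each reference answer; keep the best,
    # then zero it out if it falls below the threshold.
    best = 0.0
    for gt in ground_truths:
        p, g = prediction.strip().lower(), gt.strip().lower()
        dist = levenshtein(p, g)
        nls = 1.0 - dist / max(len(p), len(g), 1)
        best = max(best, nls)
    return best if best >= threshold else 0.0
```

The final ANLS is the mean of `anls_score` over all questions.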
Ucas-HaoranWei commented 4 months ago

Oh, sorry, I had no time last week. You need to add a small post-processing step: `if answer[-1] == '.': answer = answer[:-1]`. The trailing period affects the score by about 10 points.
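
For concreteness, in the script above this would sit right after the answer is decoded and stripped, before it is appended to `results`; a minimal sketch of the suggested post-processing:

```python
# Drop a single trailing period before scoring; this is the post-processing
# suggested above and is what recovers the missing ANLS points.
if answer.endswith('.'):
    answer = answer[:-1]
```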

Veason-silverbullet commented 4 months ago

Yes, the biggest problem was indeed the trailing period. The DocVQA ANLS rises from 65.10 to 74.86. What a quirky DocVQA evaluation... Many thanks for your assistance.