Veason-silverbullet closed this issue 4 months ago.
Maybe your prompt template is not aligned with ours. I will provide the evaluation code when I have time this week.
So, may I ask whether the released checkpoint is from the pretraining stage or the SFT stage?
The SFT model.
Since there has been no update, here is my inference code (DocVQA ANLS = 65.10) for reference. As far as I can tell, the code below is correct.
"""Evaluate a Vary checkpoint on the DocVQA validation split and report ANLS.

Pipeline:
    1. Load the tokenizer, the Vary model and the two image processors.
    2. Generate one answer per question with greedy decoding, post-process it,
       and write all answers to ``eval-outputs/DocVQA.json``.
    3. Score the answers against the annotation file with ``evaluate_anls``.
"""
import json
import os

import torch
from PIL import Image
from transformers import AutoTokenizer, CLIPImageProcessor, TextStreamer

from vary.utils.conversation import conv_templates, SeparatorStyle
from vary.utils.utils import disable_torch_init, KeywordsStoppingCriteria
from vary.model import varyQwenForCausalLM
from vary.model.plug.transforms import test_transform
from eval_utils.anls import evaluate_anls

# Special tokens that wrap the image placeholder in the prompt.
# NOTE(review): in the pasted script these literals had been stripped to
# ''/' ' by HTML rendering, which silently removes the image tokens from the
# prompt. The values below follow Vary's demo script (run_qwen_vary.py);
# confirm against the Vary repository.
DEFAULT_IMAGE_TOKEN = '<image>'
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'

dataset_path = '../datasets/DocVQA'
annotation_file = os.path.join(dataset_path, 'val_v1.0_withQT.json')
output_dir = 'eval-outputs'
output_file = os.path.join(output_dir, 'DocVQA.json')


def build_prompt(question, use_im_start_end=True, image_token_len=256):
    """Build the conversation prompt for one question.

    Args:
        question: Raw question text; surrounding whitespace is stripped.
        use_im_start_end: If True, prepend ``image_token_len`` patch tokens
            wrapped in image start/end tokens; otherwise prepend the single
            image token followed by a newline.
        image_token_len: Number of image patch tokens to insert.

    Returns:
        ``(prompt, stop_str)``: the full prompt from the ``mpt`` conversation
        template, and the separator string that marks the end of the answer.
    """
    question = question.strip()
    if use_im_start_end:
        image_tokens = (DEFAULT_IM_START_TOKEN
                        + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
                        + DEFAULT_IM_END_TOKEN)
        question = image_tokens + question
    else:
        question = DEFAULT_IMAGE_TOKEN + '\n' + question

    conv = conv_templates['mpt'].copy()
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)  # empty slot for the model's reply
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    return conv.get_prompt(), stop_str


def postprocess_answer(answer, stop_str):
    """Clean a raw decoded answer before ANLS scoring.

    Strips surrounding whitespace, a trailing ``stop_str`` (if the model
    emitted it) and a single trailing period. The trailing period matters:
    removing it moved DocVQA ANLS from 65.10 to 74.86 for this script.

    Args:
        answer: Text decoded from the generated token ids.
        stop_str: Conversation separator; ignored when empty (an empty
            ``stop_str`` would otherwise match and truncate the whole answer).

    Returns:
        The cleaned answer string (possibly empty).
    """
    answer = answer.strip()
    if stop_str and answer.endswith(stop_str):
        answer = answer[:-len(stop_str)].strip()
    # endswith() instead of answer[-1] so an empty answer does not raise.
    if answer.endswith('.'):
        answer = answer[:-1]
    return answer


def _load_model(model_path='../vary-base', clip_path='../clip-vit-large-patch14'):
    """Load tokenizer, Vary model (fp16, on CUDA, eval mode) and CLIP image processor."""
    disable_torch_init()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = varyQwenForCausalLM.from_pretrained(
        model_path, low_cpu_mem_usage=True, device_map='cpu', trust_remote_code=True)
    model.to(device='cuda', dtype=torch.float16)
    model.eval()
    image_processor = CLIPImageProcessor.from_pretrained(clip_path, torch_dtype=torch.float16)
    return tokenizer, model, image_processor


def _answer_question(model, tokenizer, image_processor, image_path, prompt,
                     stop_str, streamer, max_new_tokens=256):
    """Run greedy generation for one (image, prompt) pair and return the cleaned answer."""
    inputs = tokenizer([prompt])
    # Context manager closes the underlying file; convert() returns a loaded copy.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    # Two views of the same page: CLIP input and Vary's high-resolution branch.
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
    image_tensor_high = test_transform(image.copy())

    input_ids = torch.as_tensor(inputs.input_ids).cuda()
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
        output_ids = model.generate(
            input_ids,
            images=[(image_tensor.unsqueeze(0).half().cuda(),
                     image_tensor_high.unsqueeze(0).half().cuda())],
            do_sample=False,
            num_beams=1,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            stopping_criteria=[stopping_criteria],
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    raw_answer = tokenizer.decode(output_ids[0, input_ids.shape[1]:])
    return postprocess_answer(raw_answer, stop_str)


def main():
    """Run the full DocVQA evaluation and print the ANLS score."""
    os.makedirs(output_dir, exist_ok=True)

    # 1. Load model and configuration.
    tokenizer, model, image_processor = _load_model()
    # TextStreamer resets its state at the end of each generate() call,
    # so one instance can be reused for every question.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # 2. Evaluation by QA pairs.
    with open(annotation_file, 'r', encoding='utf-8') as f:
        data = json.load(f)['data']

    results = []
    for item in data:
        prompt, stop_str = build_prompt(item['question'])
        answer = _answer_question(
            model, tokenizer, image_processor,
            os.path.join(dataset_path, item['image']),
            prompt, stop_str, streamer)
        results.append({'questionId': str(item['questionId']), 'answer': answer})

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f)

    # 3. Compute ANLS scores.
    anls = evaluate_anls(annotation_file, output_file,
                         show_answer_types=False, anls_threshold=0.5)
    print('ANLS =', anls)


if __name__ == '__main__':
    main()
Oh, sorry, I had no time last week. You need to add a small post-processing step: `if answer[-1] == '.': answer = answer[:-1]`. The period at the end of a sentence costs about 10 ANLS points.
Yes, the biggest problem was the trailing period: with it removed, DocVQA ANLS rises from 65.10 to 74.86. DocVQA evaluation is surprisingly sensitive to details like this. Many thanks for your assistance.
Dear Authors,
Many thanks for your contribution. I evaluated the Vary-base checkpoint on DocVQA and ChartQA and obtained the following performance:
The results are lower than those reported. Could you tell me whether the Vary-base checkpoint is from Stage 1 or Stage 2? If it is the Stage-1 checkpoint, the gap would make sense; in that case, is the Stage-2 checkpoint available?
Thanks for your attention.