insafim opened 3 weeks ago
Can you please provide your inference code for the Qwen2-VL-7B model? I am getting only 41.3% for the standard 4-options case.

Below is my inference code.

```python
import os
import sys
import json
import torch
import yaml
import re
import ast
from PIL import Image
from tqdm import tqdm
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info  # pip install qwen-vl-utils
import time

# Prompt variant can be passed on the command line, e.g. `python infer.py direct1`
if len(sys.argv) == 2:
    PROMPT = sys.argv[1]
else:
    PROMPT = 'direct1'

MODEL = "Qwen/Qwen2-VL-7B-Instruct"
SETTING = 'standard'
MODEL_ID = "Qwen2-vl"

PROMPTS_FILE = "Prompts/prompts_mmmu-pro.yaml"
LOCAL_DATA_PATH = "Datasets/MMMU-Pro/MMMU-Pro_standard_4options.json"
IMAGE_FOLDER = "Datasets/MMMU-Pro/Images-standard"
OUTPUT_JSON_PATH = f"Results/mmmu-pro_{MODEL_ID}_{SETTING}_{PROMPT}.json"
MAX_RETRY = 3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Device selected: {DEVICE}")

# device_map="auto" already places the weights, so no extra .to(DEVICE) is needed
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Bound the per-image visual token budget (each token covers a 28x28 patch)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
processor = AutoProcessor.from_pretrained(
    MODEL,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)

with open(PROMPTS_FILE, "r") as file:
    prompt_config = yaml.safe_load(file)
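# For context, the prompt YAML is assumed to be organized per setting, along
# these lines (illustrative layout, not copied from the actual file):
#
#   standard:
#     direct1: "Answer with the option letter directly."
#     answer_handler: "..."
#   vision: "..."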
def replace_images_tokens(input_string):
    # Swap the MMMU "<image N>" placeholders for a generic [image] marker
    for i in range(1, 8):
        question_text = f"<image {i}>"
        query_text = "[image]"
        input_string = input_string.replace(question_text, query_text)
    return input_string

def parse_options(options):
    # Render the options as "A. ...", "B. ...", ...
    option_letters = [chr(ord("A") + i) for i in range(len(options))]
    choices_str = "\n".join([f"{option_letter}. {option}" for option_letter, option in zip(option_letters, options)])
    return choices_str

def construct_prompt(doc):
    question = doc["question"]
    parsed_options = parse_options(ast.literal_eval(str(doc["options"])))
    prompt = prompt_config[SETTING][PROMPT]
    answer_handler = prompt_config[SETTING]["answer_handler"]  # loaded but not used below
    question = f"{question}\n{parsed_options}\n{prompt}"
    return question

def mmmu_doc_to_text(doc):
    question = construct_prompt(doc)
    return replace_images_tokens(question)

def origin_mmmu_doc_to_visual(doc):
    # Standard setting: up to seven images referenced as image_1 ... image_7
    visual = []
    print(f"Extracting images for doc id: {doc.get('id', 'Unknown')}")
    for i in range(1, 8):
        image_filename = doc.get(f'image_{i}')
        if image_filename:
            image_path = os.path.join(IMAGE_FOLDER, image_filename)
            if os.path.exists(image_path):
                print(f"Found image at {image_path}")
                visual.append(image_path)
            else:
                print(f"Image {image_filename} not found at {image_path}")
    return visual

def vision_mmmu_doc_to_visual(doc):
    # Vision setting: a single screenshot per record
    image_filename = doc.get('image')
    if image_filename:
        image_path = os.path.join(IMAGE_FOLDER, image_filename)
        if os.path.exists(image_path):
            print(f"Found image at {image_path}")
            return [image_path]
        else:
            print(f"Image {image_filename} not found at {image_path}")
    return []

def process_prompt(data):
    if SETTING == 'standard':
        prompt = mmmu_doc_to_text(data)
        images = origin_mmmu_doc_to_visual(data)
    elif SETTING == 'vision':
        prompt = prompt_config['vision']
        images = vision_mmmu_doc_to_visual(data)
    else:
        raise ValueError(f"Unsupported SETTING: {SETTING}")

    conversation_content = [{"type": "text", "text": prompt}]
    for img_path in images:
        conversation_content.append({"type": "image", "image": img_path})
    return prompt, conversation_content
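# The resulting user turn follows the Qwen2-VL message format, e.g.:
# [{"role": "user",
#   "content": [{"type": "text", "text": "<question>\nA. ...\nB. ...\n<prompt>"},
#               {"type": "image", "image": "Datasets/MMMU-Pro/Images-standard/<file>"}]}]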
def initialize_json(file_path):
    # Create the output file with an empty list if it does not exist yet
    if not os.path.exists(file_path):
        print(f"Initializing new JSON file at: {file_path}")
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump([], f, ensure_ascii=False, indent=4)
    else:
        print(f"JSON file already exists at: {file_path}")

def load_existing_data(file_path):
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            existing_data = json.load(f)
        return existing_data
    except Exception as e:
        print(f"Error loading existing data: {e}. Starting with an empty dataset.")
        return []

def update_json(file_path, new_entry):
    # Append one entry and rewrite the file in place
    try:
        with open(file_path, 'r+', encoding='utf-8') as f:
            data = json.load(f)
            data.append(new_entry)
            f.seek(0)
            json.dump(data, f, ensure_ascii=False, indent=4)
            f.truncate()
        print(f"Updated JSON file with new entry id: {new_entry.get('id', 'Unknown')}")
    except Exception as e:
        print(f"Error updating JSON file with new entry: {e}")

def run_and_save():
    # Resume support: skip ids that are already in the output file
    initialize_json(OUTPUT_JSON_PATH)
    existing_data = load_existing_data(OUTPUT_JSON_PATH)
    processed_ids = {entry['id'] for entry in existing_data}
    try:
        print(f"Loading dataset from: {LOCAL_DATA_PATH}")
        with open(LOCAL_DATA_PATH, 'r', encoding='utf-8') as json_file:
            dataset = json.load(json_file)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        sys.exit(1)
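    # Each record is assumed to carry (fields inferred from the code below):
    #   "id", "question", "options" (stringified list), and either
    #   "image_1" ... "image_7" (standard) or a single "image" (vision)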
    for idx, data in enumerate(tqdm(dataset, desc="Processing dataset")):
        entry_id = data.get('id', 'Unknown')
        if entry_id in processed_ids:
            print(f"Skipping already processed entry id: {entry_id}")
            continue
        prompt, conversation_content = process_prompt(data)
        messages = [{"role": "user", "content": conversation_content}]
        try:
            print(f"Preparing input for model inference with doc id: {entry_id}")
            # text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, add_vision_id=True)
            text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            print(f"Text after applying chat template: {text[:100]}...")
            image_inputs, video_inputs = process_vision_info(messages)
            print(f"Image inputs: {image_inputs}, Video inputs: {video_inputs}")
            inputs = processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            ).to(DEVICE)
        except Exception as e:
            print(f"Error while processing prompt for id {entry_id}: {str(e)}")
            data['response'] = ''
            update_json(OUTPUT_JSON_PATH, data)
            continue
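        # Note: if the checkpoint's default generation config enables sampling,
        # scores can vary between runs; passing do_sample=False to
        # model.generate forces greedy decoding for benchmarking.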
decoded_output = ""
retry_count = 0
# while not decoded_output and retry_count < MAX_RETRY:
try:
output = model.generate(**inputs, max_new_tokens=1024, return_dict_in_generate=True, output_hidden_states=True)
generated_tokens = output.sequences[:, inputs['input_ids'].shape[-1]:]
decoded_output = processor.decode(generated_tokens[0], skip_special_tokens=True)
if not decoded_output:
retry_count += 1
except Exception as e:
retry_count += 1
data['response'] = decoded_output if decoded_output else ''
update_json(OUTPUT_JSON_PATH, data)
def main():
    start_time = time.time()  # start timing
    run_and_save()
    end_time = time.time()  # end timing
    total_time = (end_time - start_time) / 60  # convert to minutes
    print(f"\nTotal processing time: {total_time:.2f} minutes")

if __name__ == '__main__':
    main()
```
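For what it's worth, the reported number also depends on how the option letter is parsed out of the saved responses. Below is a minimal sketch of the scoring step I run afterwards; the regex, the `answer` field, and the results path are my own assumptions, not something taken from the MMMU-Pro repo:

```python
import json
import re

def extract_choice(response: str) -> str:
    # Take the last standalone option letter A-D mentioned in the response
    matches = re.findall(r"\b([A-D])\b", response)
    return matches[-1] if matches else ""

# Hypothetical results path; adjust to the actual OUTPUT_JSON_PATH
with open("Results/mmmu-pro_Qwen2-vl_standard_direct1.json", "r", encoding="utf-8") as f:
    results = json.load(f)

# Assumes each record keeps its ground-truth letter in an "answer" field
correct = sum(1 for r in results if extract_choice(r.get("response", "")) == r.get("answer"))
print(f"Accuracy: {correct / max(len(results), 1):.1%}")
```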