MMMU-Benchmark / MMMU

This repo contains evaluation code for the paper "MMMU: A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI"
https://mmmu-benchmark.github.io/
Apache License 2.0
361 stars 26 forks source link

Qwen2-VL-7B Inference Code #42

Open insafim opened 3 weeks ago

insafim commented 3 weeks ago

Could you please share your inference code for the Qwen2-VL-7B model? I am only getting 41.3% accuracy on the standard setting with 4 options.

Below is my inference code.

insafim commented 3 weeks ago

import ast
import json
import os
import re
import sys
import tempfile
import time

import torch
import yaml
from PIL import Image
from qwen_vl_utils import process_vision_info  # pip install qwen-vl-utils
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

# Configuration

# An optional single CLI argument selects the prompt key used in the prompts YAML.
PROMPT = sys.argv[1] if len(sys.argv) == 2 else 'direct1'

MODEL = "Qwen/Qwen2-VL-7B-Instruct"
SETTING = 'standard'

MODEL_ID = "Qwen2-vl"

# Define file paths and other constants

PROMPTS_FILE = "Prompts/prompts_mmmu-pro.yaml"
LOCAL_DATA_PATH = "Datasets/MMMU-Pro/MMMU-Pro_standard_4options.json"
IMAGE_FOLDER = "Datasets/MMMU-Pro/Images-standard"
# run_and_save() reads OUTPUT_JSON_PATH, so the constant must carry exactly this
# name; the file name encodes model id, setting and prompt key.
OUTPUT_JSON_PATH = f"Results/mmmu-pro_{MODEL_ID}_{SETTING}_{PROMPT}.json"
MAX_RETRY = 3  # generation attempts per entry (see run_and_save)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Device selected: {DEVICE}")

# Model and processor loading

# device_map="auto" lets accelerate place the weights, so the model is NOT moved
# again with .to(DEVICE): that is redundant on a single device and unsupported
# once accelerate has sharded or offloaded the model across devices.
# torch_dtype="auto" keeps the dtype stored in the checkpoint config instead of
# forcing float16. NOTE(review): forcing fp16 on a checkpoint stored in another
# dtype can degrade outputs; confirm the checkpoint dtype if scores still differ.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL,
    torch_dtype="auto",
    device_map="auto",
)

# Pixel-count bounds handed to the processor for resizing input images.
# NOTE(review): max_pixels caps image resolution and may cost accuracy on
# detail-heavy images; confirm these bounds match the reference evaluation.
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
# Loaded once; torch_dtype/device_map are model-loading options and have no
# meaning for a processor, so they are not passed here.
processor = AutoProcessor.from_pretrained(
    MODEL,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)

# Load prompt configuration

# Explicit UTF-8, like every other file this script opens, so prompt text with
# non-ASCII characters loads the same regardless of the platform's locale.
with open(PROMPTS_FILE, "r", encoding="utf-8") as file:
    prompt_config = yaml.safe_load(file)

# Helper functions

def replace_images_tokens(input_string):
    """Return *input_string* with every ``<image N>`` tag (N = 1..7) turned into ``[image]``.

    Tags with N outside 1..7 are left untouched.
    """
    result = input_string
    for tag in (f"<image {n}>" for n in range(1, 8)):
        result = result.replace(tag, "[image]")
    return result

def parse_options(options):
    """Format *options* as lettered lines ``"A. ..."``, ``"B. ..."`` joined by newlines."""
    lines = []
    for offset, text in enumerate(options):
        letter = chr(ord("A") + offset)
        lines.append(f"{letter}. {text}")
    return "\n".join(lines)

def construct_prompt(doc):
    """Build the text prompt for one entry.

    The result is the question, the lettered options and the configured instruction
    (``prompt_config[SETTING][PROMPT]``), separated by newlines.

    ``doc["options"]`` may be a list or the string repr of a list; the string form
    is parsed with ``ast.literal_eval``.
    """
    options = doc["options"]
    if isinstance(options, str):
        options = ast.literal_eval(options)
    parsed_options = parse_options(options)
    # The previous version also looked up prompt_config[SETTING][ANSWER_HANDLER],
    # an undefined name whose result was never used; that NameError is removed.
    prompt = prompt_config[SETTING][PROMPT]
    return f"{doc['question']}\n{parsed_options}\n{prompt}"

def mmmu_doc_to_text(doc):
    """Return the full text prompt for *doc* with ``<image N>`` tags replaced by ``[image]``."""
    return replace_images_tokens(construct_prompt(doc))

def origin_mmmu_doc_to_visual(doc):
    """Collect on-disk image paths referenced by the entry's ``image_1`` .. ``image_7`` fields.

    Returns the paths (``IMAGE_FOLDER`` joined with each filename) in field order.
    Empty fields are skipped; filenames not found on disk are reported and skipped.
    """
    visual = []
    print(f"Extracting images for doc id: {doc.get('id', 'Unknown')}")
    for i in range(1, 8):
        # NOTE(review): fields are assumed to be named image_1 .. image_7 (the
        # MMMU column naming); confirm against the local JSON.
        image_filename = doc.get(f'image_{i}')
        if not image_filename:
            continue
        image_path = os.path.join(IMAGE_FOLDER, image_filename)
        if os.path.exists(image_path):
            print(f"Found image at {image_path}")
            visual.append(image_path)
        else:
            print(f"Image {image_filename} not found at {image_path}")
    return visual

def vision_mmmu_doc_to_visual(doc):
    """Return ``[path]`` for the entry's single ``'image'`` file.

    Returns ``[]`` when the field is empty or the file is missing on disk.
    """
    filename = doc.get('image')
    if not filename:
        return []
    image_path = os.path.join(IMAGE_FOLDER, filename)
    if not os.path.exists(image_path):
        print(f"Image {filename} not found at {image_path}")
        return []
    print(f"Found image at {image_path}")
    return [f"{image_path}"]

def process_prompt(data):
    """Build the text prompt and chat content parts for one dataset entry.

    Returns ``(prompt, conversation_content)``. ``conversation_content`` is a list
    containing one text part followed by one image part per image path.

    Raises:
        ValueError: if ``SETTING`` is neither ``'standard'`` nor ``'vision'``
            (previously this fell through to an UnboundLocalError).
    """
    if SETTING == 'standard':
        prompt = mmmu_doc_to_text(data)
        images = origin_mmmu_doc_to_visual(data)
    elif SETTING == 'vision':
        # Same per-setting lookup as construct_prompt (prompt_config[SETTING][PROMPT]);
        # prompt_config['vision'] alone is the whole table, not a prompt string.
        # NOTE(review): assumes the YAML nests prompts under each setting — confirm.
        prompt = prompt_config['vision'][PROMPT]
        images = vision_mmmu_doc_to_visual(data)
    else:
        raise ValueError(f"Unsupported SETTING: {SETTING!r}")

    # NOTE(review): images are appended after the text part; the order of image
    # and text parts may affect accuracy — confirm against the reference setup.
    conversation_content = [{"type": "text", "text": prompt}]
    conversation_content.extend({"type": "image", "image": img_path} for img_path in images)

    return (prompt, conversation_content)

def initialize_json(file_path):
    """Create *file_path* holding an empty JSON list unless the file already exists.

    Missing parent directories (e.g. ``Results/``) are created first, because
    ``open(..., 'w')`` does not create them and would raise FileNotFoundError.
    """
    if os.path.exists(file_path):
        print(f"JSON file already exists at: {file_path}")
        return
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    print(f"Initializing new JSON file at: {file_path}")
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump([], f, ensure_ascii=False, indent=4)

def load_existing_data(file_path):
    """Return the parsed JSON content of *file_path*.

    Any read or parse error is reported, and ``[]`` is returned in its place.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as fh:
            return json.load(fh)
    except Exception as err:
        print(f"Error loading existing data: {err}. Starting with an empty dataset.")
        return []

def update_json(file_path, new_entry):
    """Append *new_entry* to the JSON list stored at *file_path*, best effort.

    The whole list is re-read and re-written on every call. The new content goes
    to a temporary file in the same directory and is then swapped in with
    ``os.replace``. An interruption mid-write therefore leaves the previous file
    intact rather than truncated; a truncated file would make load_existing_data
    fall back to ``[]`` and lose the resume state.

    Errors are printed, not raised.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        data.append(new_entry)

        directory = os.path.dirname(os.path.abspath(file_path))
        fd, tmp_path = tempfile.mkstemp(dir=directory, suffix='.tmp')
        try:
            with os.fdopen(fd, 'w', encoding='utf-8') as tmp:
                json.dump(data, tmp, ensure_ascii=False, indent=4)
            os.replace(tmp_path, file_path)
        finally:
            # Only still present if the dump or the replace failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        print(f"Updated JSON file with new entry id: {new_entry.get('id', 'Unknown')}")
    except Exception as e:
        print(f"Error updating JSON file with new entry: {e}")

def _load_dataset():
    """Load the evaluation entries from LOCAL_DATA_PATH; exit the script on failure."""
    try:
        print(f"Loading dataset from: {LOCAL_DATA_PATH}")
        with open(LOCAL_DATA_PATH, 'r', encoding='utf-8') as json_file:
            return json.load(json_file)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        sys.exit(1)


def _prepare_inputs(messages, entry_id):
    """Apply the chat template and run the processor on one conversation, returning tensors on DEVICE."""
    print(f"Preparing input for model inference with doc id: {entry_id}")
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(f"Text after applying chat template: {text[:100]}...")

    image_inputs, video_inputs = process_vision_info(messages)
    print(f"Image inputs: {image_inputs}, Video inputs: {video_inputs}")

    return processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(DEVICE)


def _generate_response(inputs, entry_id):
    """Generate and decode the model answer, retrying up to MAX_RETRY times.

    A retry happens when generation raises or decodes to an empty string.
    Returns ``''`` if every attempt fails.
    """
    decoded_output = ""
    for attempt in range(1, MAX_RETRY + 1):
        try:
            # Only the token ids are needed. Requesting hidden states as well would
            # keep every layer's activations for all generated steps in memory.
            output_ids = model.generate(**inputs, max_new_tokens=1024)
            # Drop the prompt tokens so only the newly generated answer is decoded.
            generated_tokens = output_ids[:, inputs['input_ids'].shape[-1]:]
            decoded_output = processor.decode(generated_tokens[0], skip_special_tokens=True)
        except Exception as e:
            print(f"Generation failed for id {entry_id} (attempt {attempt}/{MAX_RETRY}): {e}")
            continue
        if decoded_output:
            break
        print(f"Empty response for id {entry_id} (attempt {attempt}/{MAX_RETRY})")
    return decoded_output


def run_and_save():
    """Run inference over the dataset and append each result to OUTPUT_JSON_PATH.

    Entries whose id already appears in the output file are skipped, so an
    interrupted run can be resumed. Each processed entry is written back with an
    added ``'response'`` key. That key is ``''`` when input preparation or every
    generation attempt fails.
    """
    initialize_json(OUTPUT_JSON_PATH)
    existing_data = load_existing_data(OUTPUT_JSON_PATH)
    # .get: an entry saved without an 'id' must not crash the resume step.
    processed_ids = {entry.get('id') for entry in existing_data}

    dataset = _load_dataset()

    for data in tqdm(dataset, desc="Processing dataset"):
        entry_id = data.get('id', 'Unknown')
        if entry_id in processed_ids:
            print(f"Skipping already processed entry id: {entry_id}")
            continue

        _, conversation_content = process_prompt(data)
        messages = [{"role": "user", "content": conversation_content}]

        try:
            inputs = _prepare_inputs(messages, entry_id)
        except Exception as e:
            print(f"Error while processing prompt for id {entry_id}: {str(e)}")
            data['response'] = ''
            update_json(OUTPUT_JSON_PATH, data)
            continue

        data['response'] = _generate_response(inputs, entry_id)
        update_json(OUTPUT_JSON_PATH, data)

def main():
    """Run the full evaluation and report its wall-clock duration in minutes."""
    t0 = time.time()
    run_and_save()
    elapsed_minutes = (time.time() - t0) / 60
    print(f"\nTotal processing time: {elapsed_minutes:.2f} minutes")

# Run only when executed as a script, not on import. The previous guard compared
# an undefined name `name` to 'main', which raised NameError at import time.
if __name__ == "__main__":
    main()