csuhan / OneLLM

[CVPR 2024] OneLLM: One Framework to Align All Modalities with Language
591 stars 32 forks source link

Audio-Video-Text Evaluation Scripts are missing (Table 4 of OneLLM paper) #29

Open vittoriopipoli opened 1 month ago

vittoriopipoli commented 1 month ago

Hi @csuhan,

I recently came across your paper presented at CVPR2024, where you introduced the OneLLM model. I found your work highly interesting and particularly relevant to my research. I am keen to conduct a detailed study on the scenarios where OneLLM processes inputs from multiple modalities, such as the audio-video-text cases described in Table 4 of your paper.

However, upon reviewing the resources available at this repository, I was unable to locate the scripts that handle experiments involving more than two modalities. I was wondering if you could kindly share the code for three-modality cases or guide me on how to proceed in setting up such experiments.

I would greatly appreciate any assistance or guidance you can provide on this matter. Thank you for your time, and I look forward to your response.

qixueweigitbub commented 1 month ago

I have the same request. I am waiting for a simple demo script that runs the model on video input with an audio track.

csuhan commented 5 days ago
import sys
import os
import json
import types
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import multiprocessing as mp
import torch.distributed as dist
from fairscale.nn.model_parallel import initialize as fs_init
from util.misc import default_tensor_type
import torchvision.transforms as transforms
from model.meta import MetaModel
from data.conversation_lib import conv_templates

from data.data_utils import make_audio_features
from data import video_utils

def load_audio(audio_path):
    """Load one audio file as filter-bank features with a leading singleton axis.

    The features come from ``make_audio_features`` with 128 mel bins; the
    first two axes are swapped and a new axis is prepended.
    Per the original author's note the result is shaped [1, 128, 1024].
    """
    mel_features = make_audio_features(audio_path, mel_bins=128)
    swapped = mel_features.transpose(0, 1)
    return swapped[None]

def load_video(video_path):
    """Load a video as clip features and keep index 0 of the third axis.

    The video is sampled by ``video_utils.load_and_transform_video_data``
    with 5 clips of 1 second each.
    # NOTE(review): presumably the third axis is the per-clip frame axis,
    # so this keeps one frame per clip - confirm against video_utils.
    """
    clips = video_utils.load_and_transform_video_data(
        video_path,
        video_path,
        clip_duration=1,
        clips_per_video=5,
    )
    return clips[:, :, 0]

# Image preprocessing: bicubic resize to 224, center crop, tensor conversion,
# then per-channel normalization with the mean/std below.
# NOTE(review): the pasted script was truncated here. The Resize wrapper and
# the CenterCrop/ToTensor steps are reconstructed from the pipeline's name and
# the surviving arguments - confirm the crop size against the model config.
T_resized_center_crop = transforms.Compose([
    transforms.Resize(
        224, interpolation=transforms.InterpolationMode.BICUBIC
    ),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    ),
])

# Tokenizer ids of the placeholder words; positions holding these ids are
# overwritten with encoded modality features during the first forward pass.
IMAGE_ID = 5473  # VI
AUDIO_ID = 13408  # VII
VIDEO_ID = 15682  # VIII

# Placeholder words written into the prompt text for each modality input.
# They must tokenize to the ids above (see the inline notes on each id).
IMAGE_TAG = "VI"
AUDIO_TAG = "VII"
VIDEO_TAG = "VIII"

# How many placeholder words are emitted per modality input.
# NOTE(review): assumed to equal the number of feature tokens the encoder
# produces per input (30) - confirm against the model definition.
NUM_MODAL_TOKEN = 30

@torch.inference_mode()
def MetaModel_generate(
    self,
    prompts,
    images=None,
    audios=None,
    videos=None,
    max_gen_len: int = 32,
    temperature: float = 0.8,
    top_p: float = 0.95,
    modal = ['image'],
):
    """Autoregressively generate text for a single prompt with mixed-modal inputs.

    Intended to be bound onto a model object in place of its own ``generate``
    (it reads ``self.llma``, ``self.tokenizer`` and ``self.sample_top_p``).
    The decorator is applied here because a replacement bound with
    ``types.MethodType`` does not inherit any decorator of the method it
    replaces, and generation needs no autograd state.

    Args:
        prompts: list containing exactly one prompt string.
        images, audios, videos: modality inputs, or None. They are forwarded
            to ``self.llma.forward_inference`` only on the first step
            (``prev_pos == 0``); later steps pass None.
        max_gen_len: maximum number of generated tokens.
        temperature: values > 0 enable top-p sampling; otherwise greedy argmax.
        top_p: threshold passed to ``self.sample_top_p``.
        modal: forwarded unchanged to ``forward_inference``.

    Returns:
        A list with one decoded string, cut at the first EOS token if present.

    Raises:
        ValueError: if ``prompts`` does not contain exactly one entry.
    """
    bsz = len(prompts)
    if bsz != 1:
        raise ValueError(f"only batch size 1 is supported, got {bsz}")
    params = self.llma.params

    prompt_tokens = [
        self.tokenizer.encode(x, bos=True, eos=False) for x in prompts
    ]
    min_prompt_size = min(len(t) for t in prompt_tokens)
    max_prompt_size = max(len(t) for t in prompt_tokens)

    total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

    # Token buffer pre-filled with pad; use the GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokens = torch.full(
        (bsz, total_len), self.tokenizer.pad_id, dtype=torch.long, device=device
    )
    for k, t in enumerate(prompt_tokens):
        tokens[k, : len(t)] = torch.tensor(t).long()
    input_text_mask = tokens != self.tokenizer.pad_id

    start_pos = min_prompt_size
    prev_pos = 0
    for cur_pos in range(start_pos, total_len):
        first_step = prev_pos == 0
        logits = self.llma.forward_inference(
            tokens[:, prev_pos:cur_pos],
            prev_pos,
            images if first_step else None,
            audios if first_step else None,
            videos if first_step else None,
            modal=modal,
        )
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = self.sample_top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(-1)
        # only replace token if prompt has already been generated
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        tokens[:, cur_pos] = next_token
        prev_pos = cur_pos

    decoded = []
    for i, t in enumerate(tokens.tolist()):
        # cut to max gen len
        t = t[: len(prompt_tokens[i]) + max_gen_len]
        # cut to eos tok if any
        try:
            t = t[: t.index(self.tokenizer.eos_id)]
        except ValueError:
            pass  # no EOS generated: keep the full sequence
        decoded.append(self.tokenizer.decode(t))
    return decoded

@torch.inference_mode()
def OneLLM_forward_inference(self, tokens: torch.Tensor, start_pos: int, image=None, audio=None, video=None, modal='image'):
    """Run one inference step and return the logits of the last position.

    Intended to be bound onto the language model in place of its own
    ``forward_inference``. When any modality input is given, its encoded
    features overwrite the embeddings at the positions whose token id equals
    the matching placeholder id (IMAGE_ID / AUDIO_ID / VIDEO_ID).

    Args:
        tokens: token ids of shape (1, seqlen).
        start_pos: position of the first token of ``tokens`` in the sequence.
        image, audio, video: modality inputs or None.
        modal: accepted for interface compatibility only; each input is
            encoded with its own fixed modality name instead.

    Returns:
        Float logits computed from the last position only.

    Raises:
        ValueError: if the batch size is not 1.
    """
    _bsz, seqlen = tokens.shape
    if _bsz != 1:
        raise ValueError(f"only batch size 1 is supported, got {_bsz}")

    if start_pos == 0:
        # kv cache will not re-allocate if size is unchanged
        # NOTE(review): this call was lost in the pasted script; the helper
        # name is reconstructed - confirm it exists on the model class.
        self._allocate_kv_cache(_bsz)
    h = self.tok_embeddings(tokens)
    self.freqs_cis = self.freqs_cis.to(h.device)

    if image is None and audio is None and video is None:
        if start_pos == 0:
            self.cache_image_words = 0
            freqs_cis = self.freqs_cis[0: seqlen]
        else:
            # if image was not None when start_pos=0,
            # the offset should be added to start_pos within later forward_inference calls
            # NOTE(review): features are scattered over placeholder positions
            # in place, so the sequence length is not extended by them; verify
            # that this offset is still wanted for the kv-cache/rotary index.
            start_pos = start_pos + self.cache_image_words
            freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]
    else:
        # A prompt carrying modality inputs starts a new sequence: reset the
        # counter so it does not keep growing across successive prompts.
        self.cache_image_words = 0
        hidden_dim = h.shape[-1]
        modal_specs = (
            (image, IMAGE_ID, 'image'),
            (audio, AUDIO_ID, 'audio'),
            (video, VIDEO_ID, 'video'),
        )
        for modal_input, modal_id, modal_name in modal_specs:
            if modal_input is None:
                continue
            modal_tokens = self.encode_image(modal_input, modal_name)
            # flatten all inputs of this modality into one run of feature rows
            modal_tokens = modal_tokens.reshape(-1, hidden_dim)
            modal_tokens = modal_tokens.to(h.device, h.dtype)
            # select every embedding element that sits on a placeholder token
            special_mask = (tokens == modal_id)[:, :, None].repeat(1, 1, hidden_dim)
            h = h.masked_scatter(special_mask, modal_tokens)
            self.cache_image_words += modal_tokens.shape[0]
        seqlen = h.shape[1]
        freqs_cis = self.freqs_cis[0: seqlen]

    mask = None
    if seqlen > 1:
        # causal mask: -inf strictly above the (start_pos-shifted) diagonal
        mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
        mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

    for layer in self.layers:
        h = layer(h, start_pos, freqs_cis, mask)
    h = self.norm(h)
    output = self.output(h[:, -1, :])  # only compute last logits
    return output.float()

class onellm_evaluation(nn.Module):
    """Multiple-choice evaluator that feeds mixed image/audio/video prompts to OneLLM.

    The wrapped model's ``generate`` and its language model's
    ``forward_inference`` are replaced with the mixed-modal versions defined
    in this file (``MetaModel_generate`` / ``OneLLM_forward_inference``).
    """

    def __init__(self, model_path="./OneLLM-7B", image_folder=None, video_folder=None, audio_folder=None):
        """Build the model, patch its inference functions and load the weights.

        Args:
            model_path: directory containing ``consolidated.00-of-01.pth``.
            image_folder, video_folder, audio_folder: base directories that the
                file names given to ``evaluate_all`` are resolved against.
                None means the names are used as given.
        """
        super().__init__()

        # NOTE(review): the process-group setup was truncated in the pasted
        # script; only the backend/rank/world_size arguments survived. The
        # rest (start method, init_method address, model-parallel init) is
        # reconstructed for a single-process run - confirm before relying on it.
        if mp.get_start_method(allow_none=True) is None:
            mp.set_start_method("spawn")
        if not dist.is_initialized():
            dist.init_process_group(
                backend="nccl", rank=0, world_size=1,
                init_method="tcp://127.0.0.1:23560")
        if not fs_init.model_parallel_is_initialized():
            fs_init.initialize_model_parallel(1)

        # NOTE(review): the key selecting the dtype was lost; fp16 is assumed.
        self.target_dtype = {
            "bf16": torch.bfloat16,
            "fp16": torch.float16
        }["fp16"]
        with default_tensor_type(dtype=self.target_dtype, device="cuda"):
            self.model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model")

        # replace OneLLM with modified generation function
        self.model.generate = types.MethodType(MetaModel_generate, self.model)
        self.model.llma.forward_inference = types.MethodType(OneLLM_forward_inference, self.model.llma)

        print("Loading pretrained weights ...")
        checkpoint = torch.load(os.path.join(model_path, "consolidated.00-of-01.pth"), map_location='cpu')
        msg = self.model.load_state_dict(checkpoint, strict=False)
        print("load result:\n", msg)
        # inference only: make sure train-time behaviour is switched off
        self.model.eval()

        self.image_folder = image_folder
        self.video_folder = video_folder
        self.audio_folder = audio_folder
        self.question_prompt = "Answer with the option's letter from the given choices directly."

    @staticmethod
    def _resolve(folder, file_name):
        """Join ``file_name`` onto ``folder``; a None folder leaves the name as is."""
        return os.path.join(folder or "", file_name)

    def _to_batch(self, tensors):
        """Stack per-file tensors into one CUDA batch, or return None if empty."""
        if not tensors:
            return None
        return torch.stack(tensors).cuda().to(self.target_dtype)

    def evaluate_all(self, image_path, audio_path, video_path, question, options):
        """Answer one multiple-choice question that references modality inputs.

        Args:
            image_path: list of image file names (may be empty).
            audio_path: list of audio file names (may be empty).
            video_path: list of video file names (may be empty).
            question: question text; ``[imgK]`` / ``[audioK]`` / ``[videoK]``
                (K starting at 1) mark where the K-th input of that modality
                is referenced. The same markers are also expanded inside the
                option texts.
            options: option texts, labelled A-D in order (at most four are used).

        Returns:
            The stripped model response for the prompt, cut at the first '###'.
        """
        option_text = "".join(
            f"{letter}. {option}\n" for letter, option in zip("ABCD", options)
        )
        text = question + "\n" + option_text + self.question_prompt

        # Expand each marker into a run of placeholder words; their embeddings
        # are overwritten with the encoded modality features at inference time.
        marker_specs = (
            ("img", IMAGE_TAG, image_path),
            ("audio", AUDIO_TAG, audio_path),
            ("video", VIDEO_TAG, video_path),
        )
        for prefix, tag, paths in marker_specs:
            placeholder_run = " " + " ".join([tag] * NUM_MODAL_TOKEN) + " "
            for index in range(1, len(paths) + 1):
                text = text.replace(f"[{prefix}{index}]", placeholder_run)

        conv = conv_templates["v1"].copy()
        conv.append_message(conv.roles[0], text)
        conv.append_message(conv.roles[1], None)
        prompts = [conv.get_prompt()]

        image_list = []
        for image_file in image_path:
            with Image.open(self._resolve(self.image_folder, image_file)) as img:
                image_list.append(T_resized_center_crop(img.convert('RGB')))
        image_tensors = self._to_batch(image_list)

        audio_tensors = self._to_batch([
            load_audio(self._resolve(self.audio_folder, audio_file))
            for audio_file in audio_path
        ])

        video_tensors = self._to_batch([
            load_video(self._resolve(self.video_folder, video_file))
            for video_file in video_path
        ])

        with torch.cuda.amp.autocast(dtype=self.target_dtype):
            # NOTE(review): the original call's arguments were truncated;
            # sampling settings are left at the generate() defaults.
            responses = self.model.generate(
                prompts,
                images=image_tensors,
                audios=audio_tensors,
                videos=video_tensors,
            )

        # the decoded text starts with the prompt; keep only the answer part
        outputs = [
            response[len(prompt):].split('###')[0].strip()
            for response, prompt in zip(responses, prompts)
        ]
        return outputs[0]

    def evaluate_image_audio_text(self, image_path, audio_path, question, options):
        """Evaluate a question that uses images and audio but no video."""
        return self.evaluate_all(image_path, audio_path, [], question, options)

    def evaluate_video_audio_text(self, video_path, audio_path, question, options):
        """Evaluate a question that uses videos and audio but no image."""
        return self.evaluate_all([], audio_path, video_path, question, options)

if __name__ == "__main__":
    # Demo: pick which of four images matches one audio clip.
    model_path = "multimodal_llama2_7B/llama2-7B_img224-patch16_llama_clip_resampler_moe_bsz512-5120_lr2e-5_warm0.05_clip2_X_v20_finetune_8gpu/epoch_0_iter_000043000"
    # folder arguments are (image_folder, video_folder, audio_folder)
    onellm = onellm_evaluation(model_path, "examples/imgs", "examples/videos", "examples/audios")

    question = "Please select the image below that best matches the audio: [audio1] from the first image: [img1], the second image: [img2], the third image: [img3] and the fourth image: [img4]."
    options = ["the first image", "the second image", "the third image", "the fourth image"]

    audio_path = ['dog.wav']
    image_paths = ["dog.jpg", 'cat.jpg', 'bird.jpg', 'rabbit.jpg']

    res = onellm.evaluate_image_audio_text(image_paths, audio_path, question, options)

    # Show the answer instead of dropping into a debugger, so the script can
    # run unattended.
    print(res)
csuhan commented 5 days ago

Hi @vittoriopipoli @qixueweigitbub, I hope this script helps you with mixed-modal input.

xxrbudong commented 1 day ago

Thank you very much for your answer. The single-modal inference example also uses two tags, self.start_tag[modal] and self.end_tag[modal]. Don't they need to be set for mixed-modal inference as well?