csuhan / OneLLM

[CVPR 2024] OneLLM: One Framework to Align All Modalities with Language

Audio-Video-Text Evaluation Scripts are missing (Table 4 of OneLLM paper) #29

Open vittoriopipoli opened 1 month ago

vittoriopipoli commented 1 month ago

Hi @csuhan,

I recently came across your paper presented at CVPR 2024, where you introduced the OneLLM model. I found your work highly interesting and particularly relevant to my research. I am keen to conduct a detailed study of the scenarios where OneLLM processes inputs from multiple modalities, such as the audio-video-text cases described in Table 4 of your paper.

However, upon reviewing the resources available in this repository, I was unable to locate the scripts that handle experiments involving more than two modalities. Could you kindly share the code for the three-modality cases, or advise me on how to set up such experiments?

I would greatly appreciate any assistance or guidance you can provide on this matter. Thank you for your time, and I look forward to your response.

qixueweigitbub commented 1 month ago

I have the same request and am waiting for a simple demo script that runs the model on audible video input (video together with its audio track).

csuhan commented 5 days ago
import sys
sys.path.append('./')
import os
import json
import types
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import multiprocessing as mp
import torch.distributed as dist
from fairscale.nn.model_parallel import initialize as fs_init
from util.misc import default_tensor_type
import torchvision.transforms as transforms
from model.meta import MetaModel
from data.conversation_lib import conv_templates

from data.data_utils import make_audio_features
from data import video_utils

def load_audio(audio_path):
    fbank = make_audio_features(audio_path, mel_bins=128)
    fbank = fbank.transpose(0, 1)[None] #[1, 128, 1024]
    return fbank

def load_video(video_path):
    video_feats = video_utils.load_and_transform_video_data(
        video_path, video_path, clip_duration=1, clips_per_video=5)
    return video_feats[:, :, 0]

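# CLIP-style image preprocessing: bicubic resize, 224x224 center crop, and
# normalization with the CLIP mean/std.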
T_resized_center_crop = transforms.Compose([
    transforms.Resize(
        224, interpolation=transforms.InterpolationMode.BICUBIC
    ),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])

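# Placeholder tokens used to splice modality features into the prompt. "VI",
# "VII" and "VIII" each map to a single tokenizer id (given below), and every
# modality input is written into the text as NUM_MODAL_TOKEN repetitions of its
# tag, matching the number of embedding tokens produced per input so that
# masked_scatter can replace them one-for-one.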
IMAGE_ID = 5473 # VI
IMAGE_TAG = "VI"
AUDIO_ID = 13408 # VII
AUDIO_TAG = "VII"
VIDEO_ID = 15682 # VIII
VIDEO_TAG = "VIII"
NUM_MODAL_TOKEN = 30

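# Replacement for MetaModel.generate that also threads audio and video tensors
# (alongside images) into forward_inference on the first decoding step; it is
# bound onto the loaded model with types.MethodType in onellm_evaluation below.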
def MetaModel_generate(
    self,
    prompts,
    images=None,
    audios=None,
    videos=None,
    max_gen_len: int = 32,
    temperature: float = 0.8,
    top_p: float = 0.95,
    modal = ['image'],
):
    bsz = len(prompts)
    assert bsz == 1
    params = self.llma.params
    # assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

    prompt_tokens = [self.tokenizer.encode(
        x, bos=True, eos=False) for x in prompts]

    min_prompt_size = min([len(t) for t in prompt_tokens])
    max_prompt_size = max([len(t) for t in prompt_tokens])

    total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

    tokens = torch.full(
        (bsz, total_len), self.tokenizer.pad_id).cuda().long()
    for k, t in enumerate(prompt_tokens):
        tokens[k, : len(t)] = torch.tensor(t).long()
    input_text_mask = tokens != self.tokenizer.pad_id
    start_pos = min_prompt_size
    prev_pos = 0
    for cur_pos in range(start_pos, total_len):
        logits = self.llma.forward_inference(
            tokens[:, prev_pos:cur_pos], prev_pos,
            images if prev_pos == 0 else None,
            audios if prev_pos == 0 else None,
            videos if prev_pos == 0 else None,
            modal=modal)
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = self.sample_top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(-1)
        # only replace token if prompt has already been generated
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        print(next_token)
        tokens[:, cur_pos] = next_token
        prev_pos = cur_pos

    decoded = []
    for i, t in enumerate(tokens.tolist()):
        # cut to max gen len
        t = t[: len(prompt_tokens[i]) + max_gen_len]
        # cut to eos tok if any
        try:
            t = t[: t.index(self.tokenizer.eos_id)]
        except ValueError:
            pass
        decoded.append(self.tokenizer.decode(t))
    return decoded

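# Replacement for the LLM's forward_inference that accepts image, audio and video
# inputs in a single prompt: their placeholder tokens are overwritten in place with
# the corresponding encoder outputs (see the masked_scatter below).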
@torch.inference_mode()
def OneLLM_forward_inference(self, tokens: torch.Tensor, start_pos: int, image=None, audio=None, video=None, modal='image'):
    # modal = modal[0] if isinstance(modal, list) else modal
    _bsz, seqlen = tokens.shape
    assert _bsz == 1

    if start_pos == 0:
        # kv cache will not re-allocate if size is unchanged
        self._allocate_kv_cache(_bsz)
    h = self.tok_embeddings(tokens)
    self.freqs_cis = self.freqs_cis.to(h.device)

    if image is None and audio is None and video is None:
        if start_pos == 0:
            self.cache_image_words = 0
            freqs_cis = self.freqs_cis[0: seqlen]
        else:
            # if image was not None when start_pos=0,
            # the offset should be added to start_pos within later forward_inference calls
            start_pos = start_pos + self.cache_image_words
            freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]
    else:
        modal_inputs = [image, audio, video]
        modal_ids = [IMAGE_ID, AUDIO_ID, VIDEO_ID]
        modals = ['image', 'audio', 'video']

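        # For each modality that is present, encode it and scatter the resulting
        # embedding tokens into h at the positions of its placeholder token id,
        # replacing the repeated tag tokens inserted into the prompt text.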
        for modal_input, modal_id, modal in zip(modal_inputs, modal_ids, modals):
            if modal_input is not None:
                modal_tokens = self.encode_image(modal_input, modal)
                modal_tokens = modal_tokens.reshape(-1, h.shape[-1])
                special_mask = tokens == modal_id
                special_mask = special_mask[:, :, None].repeat(1,1, h.shape[-1])
                modal_tokens = modal_tokens.to(h.device, h.dtype)
                h = h.masked_scatter(special_mask, modal_tokens)
                self.cache_image_words += modal_tokens.shape[0]
        seqlen = h.shape[1]
        freqs_cis = self.freqs_cis[0: seqlen]

    # freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

    mask = None
    if seqlen > 1:
        mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
        mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

    for layer in self.layers:
        h = layer(h, start_pos, freqs_cis, mask)
    h = self.norm(h)
    output = self.output(h[:, -1, :])  # only compute last logits
    return output.float()

class onellm_evaluation(nn.Module):
    def __init__(self, model_path="./OneLLM-7B", image_folder=None, video_folder=None, audio_folder=None):
        super().__init__()

        mp.set_start_method("spawn")
        dist.init_process_group(
            backend="nccl", rank=0, world_size=1,
            init_method=f"tcp://127.0.0.1:23560")
        fs_init.initialize_model_parallel(1)
        torch.cuda.set_device(0)
        torch.manual_seed(1)
        np.random.seed(1)

        self.target_dtype = {
            "bf16": torch.bfloat16,
            "fp16": torch.float16
        }['fp16']
        with default_tensor_type(dtype=self.target_dtype, device="cuda"):
            self.model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model")

        # replace OneLLM with modified generation function
        self.model.generate = types.MethodType(MetaModel_generate, self.model)
        self.model.llma.forward_inference = types.MethodType(OneLLM_forward_inference, self.model.llma)

        print("Loading pretrained weights ...")
        checkpoint = torch.load(os.path.join(model_path, "consolidated.00-of-01.pth"), map_location='cpu')
        msg = self.model.load_state_dict(checkpoint, strict=False)
        print("load result:\n", msg)
        self.model.half().cuda().eval()

        self.image_folder = image_folder
        self.video_folder = video_folder
        self.audio_folder = audio_folder
        self.question_prompt = "Answer with the option's letter from the given choices directly."

    def evaluate_all(self, image_path, audio_path, video_path, question, options):
        # image_path: list of image file names (resolved against self.image_folder)
        # audio_path: list of audio file names (resolved against self.audio_folder)
        # video_path: list of video file names (resolved against self.video_folder)
        # question:   question text; it may contain placeholders such as [img1],
        #             [audio1] or [video1], which are expanded into modality tags below
        # options:    [option_A, option_B, option_C, option_D]

        option_text = "A. " + options[0] + "\n" + "B. " + options[1] + "\n" + "C. " + options[2] + "\n" + "D. " + options[3] + "\n" 
        text = question + "\n" + option_text + self.question_prompt

        for index in range(len(image_path)):
            text = text.replace(f"[img{index+1}]", " " + " ".join([IMAGE_TAG] * NUM_MODAL_TOKEN) + " ") # "Options: A. VI VI VI ... B. VI VI ..." 
        for index in range(len(audio_path)):
            text = text.replace(f'[audio{index+1}]', " " + " ".join([AUDIO_TAG] * NUM_MODAL_TOKEN) + " ")
        for index in range(len(video_path)):
            text = text.replace(f'[video{index+1}]', " " + " ".join([VIDEO_TAG] * NUM_MODAL_TOKEN) + " ")

        prompts = []
        conv = conv_templates["v1"].copy()        
        conv.append_message(conv.roles[0], text)
        conv.append_message(conv.roles[1], None)
        prompts.append(conv.get_prompt())

        image_tensors = []
        for image_file in image_path:
            image_file = os.path.join(self.image_folder, image_file)
            image = Image.open(image_file).convert('RGB')
            image = T_resized_center_crop(image)
            image_tensors.append(image)
        if len(image_tensors) > 0:
            image_tensors = torch.stack(image_tensors).cuda().to(self.target_dtype)
        else:
            image_tensors = None

        audio_tensors = []
        for audio_file in audio_path:
            audio_file = os.path.join(self.audio_folder, audio_file)
            audio = load_audio(audio_file)
            audio_tensors.append(audio)
        if len(audio_tensors) > 0:
            audio_tensors = torch.stack(audio_tensors).cuda().to(self.target_dtype)
        else:
            audio_tensors = None

        video_tensors = []
        for video_file in video_path:
            video_file = os.path.join(self.video_folder, video_file)
            video = load_video(video_file)
            video_tensors.append(video)
        if len(video_tensors) > 0:
            video_tensors = torch.stack(video_tensors).cuda().to(self.target_dtype)
        else:
            video_tensors = None

        with torch.cuda.amp.autocast(dtype=self.target_dtype):
            responses = self.model.generate(
                prompts,
                images=image_tensors,
                audios=audio_tensors,
                videos=video_tensors,
                max_gen_len=32,
                temperature=0.0,
                top_p=0.95)
            outputs = []
            for response, prompt in zip(responses, prompts):
                response = response[len(prompt):].split('###')[0]
                response = response.strip()
                outputs.append(response)
        return outputs[0]

    def evaluate_image_audio_text(self, image_path, audio_path, question, options):
        return self.evaluate_all(image_path, audio_path, [], question, options)

    def evaluate_video_audio_text(self, video_path, audio_path, question, options):
        return self.evaluate_all([], audio_path, video_path, question, options)

if __name__ == "__main__":
    model_path = "multimodal_llama2_7B/llama2-7B_img224-patch16_llama_clip_resampler_moe_bsz512-5120_lr2e-5_warm0.05_clip2_X_v20_finetune_8gpu/epoch_0_iter_000043000"
    onellm = onellm_evaluation(model_path, "examples/imgs", "examples/videos", "examples/audios")

    question = "Please select the image below that best matches the audio: [audio1] from the first image: [img1], the second image: [img2], the third image: [img3] and the fourth image: [img4]."
    # options = ["[img1]", "[img2]", "[img3]", "[img4]"]
    options = ["the first image", "the second image", "the third image", "the fourth image"]

    audio_path = ['dog.wav']
    image_paths = ["dog.jpg", 'cat.jpg', 'bird.jpg', 'rabbit.jpg']

    res = onellm.evaluate_image_audio_text(image_paths, audio_path, question, options)

    import pdb; pdb.set_trace()  # pause here to inspect the model response (res)
csuhan commented 5 days ago

Hi @vittoriopipoli @qixueweigitbub, I hope this script helps you with mixed-modality input.
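
For the audio-video-text setting of Table 4, a minimal usage sketch, reusing the onellm instance constructed in the __main__ block above, could look like the following; dog.mp4 is a hypothetical file name under examples/videos and the question wording is only illustrative.

# Hypothetical audio-video-text query; adapt the file names and question to your data.
question = ("Please watch the video: [video1] and listen to the audio: [audio1]. "
            "Which animal appears in the video and is making this sound?")
options = ["a dog", "a cat", "a bird", "a rabbit"]

video_paths = ["dog.mp4"]   # hypothetical file expected under examples/videos
audio_paths = ["dog.wav"]   # example audio already used in the script above

answer = onellm.evaluate_video_audio_text(video_paths, audio_paths, question, options)
print(answer)

As in the image-audio example, the model is prompted to answer directly with the option letter.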

xxrbudong commented 1 day ago

Thank you very much for your answer. The single-modal inference example also uses the two flags self.start_tag[modal] and self.end_tag[modal]. Don't they also need to be set when running inference on the different modalities?