vittoriopipoli opened this issue 1 month ago
I have the same request. I am waiting for a simple demo script that runs the model on video input with audio.
import sys
sys.path.append('./')
import os
import json
import types
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import multiprocessing as mp
import torch.distributed as dist
from fairscale.nn.model_parallel import initialize as fs_init
from util.misc import default_tensor_type
import torchvision.transforms as transforms
from model.meta import MetaModel
from data.conversation_lib import conv_templates
from data.data_utils import make_audio_features
from data import video_utils
def load_audio(audio_path):
    fbank = make_audio_features(audio_path, mel_bins=128)
    fbank = fbank.transpose(0, 1)[None]  # [1, 128, 1024]
    return fbank


def load_video(video_path):
    video_feats = video_utils.load_and_transform_video_data(
        video_path, video_path, clip_duration=1, clips_per_video=5)
    return video_feats[:, :, 0]
T_resized_center_crop = transforms.Compose([
    transforms.Resize(
        224, interpolation=transforms.InterpolationMode.BICUBIC
    ),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711])])
IMAGE_ID = 5473 # VI
IMAGE_TAG = "VI"
AUDIO_ID = 13408 # VII
AUDIO_TAG = "VII"
VIDEO_ID = 15682 # VIII
VIDEO_TAG = "VIII"
NUM_MODAL_TOKEN = 30  # number of placeholder tokens inserted per modal input
def MetaModel_generate(
    self,
    prompts,
    images=None,
    audios=None,
    videos=None,
    max_gen_len: int = 32,
    temperature: float = 0.8,
    top_p: float = 0.95,
    modal=['image'],
):
    bsz = len(prompts)
    assert bsz == 1
    params = self.llma.params
    # assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
    prompt_tokens = [self.tokenizer.encode(
        x, bos=True, eos=False) for x in prompts]
    min_prompt_size = min([len(t) for t in prompt_tokens])
    max_prompt_size = max([len(t) for t in prompt_tokens])
    total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
    tokens = torch.full(
        (bsz, total_len), self.tokenizer.pad_id).cuda().long()
    for k, t in enumerate(prompt_tokens):
        tokens[k, : len(t)] = torch.tensor(t).long()
    input_text_mask = tokens != self.tokenizer.pad_id
    start_pos = min_prompt_size
    prev_pos = 0
    for cur_pos in range(start_pos, total_len):
        # multimodal inputs are only injected on the first forward pass (prev_pos == 0)
        logits = self.llma.forward_inference(
            tokens[:, prev_pos:cur_pos], prev_pos,
            images if prev_pos == 0 else None,
            audios if prev_pos == 0 else None,
            videos if prev_pos == 0 else None,
            modal=modal)
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = self.sample_top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(-1)
        # only replace token if prompt has already been generated
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        # print(next_token)  # debug output
        tokens[:, cur_pos] = next_token
        prev_pos = cur_pos
    decoded = []
    for i, t in enumerate(tokens.tolist()):
        # cut to max gen len
        t = t[: len(prompt_tokens[i]) + max_gen_len]
        # cut to eos tok if any
        try:
            t = t[: t.index(self.tokenizer.eos_id)]
        except ValueError:
            pass
        decoded.append(self.tokenizer.decode(t))
    return decoded
@torch.inference_mode()
def OneLLM_forward_inference(self, tokens: torch.Tensor, start_pos: int,
                             image=None, audio=None, video=None, modal='image'):
    # modal = modal[0] if isinstance(modal, list) else modal
    _bsz, seqlen = tokens.shape
    assert _bsz == 1
    if start_pos == 0:
        # kv cache will not re-allocate if size is unchanged
        self._allocate_kv_cache(_bsz)
    h = self.tok_embeddings(tokens)
    self.freqs_cis = self.freqs_cis.to(h.device)
    if image is None and audio is None and video is None:
        if start_pos == 0:
            self.cache_image_words = 0
            freqs_cis = self.freqs_cis[0: seqlen]
        else:
            # if image was not None when start_pos=0,
            # the offset should be added to start_pos within later forward_inference calls
            start_pos = start_pos + self.cache_image_words
            freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]
    else:
        modal_inputs = [image, audio, video]
        modal_ids = [IMAGE_ID, AUDIO_ID, VIDEO_ID]
        modals = ['image', 'audio', 'video']
        self.cache_image_words = 0  # reset the cached modal-token count for this generation
        for modal_input, modal_id, modal in zip(modal_inputs, modal_ids, modals):
            if modal_input is not None:
                # encode the modality and scatter its tokens into the placeholder positions
                modal_tokens = self.encode_image(modal_input, modal)
                modal_tokens = modal_tokens.reshape(-1, h.shape[-1])
                special_mask = tokens == modal_id
                special_mask = special_mask[:, :, None].repeat(1, 1, h.shape[-1])
                modal_tokens = modal_tokens.to(h.device, h.dtype)
                h = h.masked_scatter(special_mask, modal_tokens)
                self.cache_image_words += modal_tokens.shape[0]
        seqlen = h.shape[1]
        freqs_cis = self.freqs_cis[0: seqlen]
        # freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
    mask = None
    if seqlen > 1:
        mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
        mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
    for layer in self.layers:
        h = layer(h, start_pos, freqs_cis, mask)
    h = self.norm(h)
    output = self.output(h[:, -1, :])  # only compute last logits
    return output.float()
class onellm_evaluation(nn.Module):
    def __init__(self, model_path="./OneLLM-7B", image_folder=None, video_folder=None, audio_folder=None):
        super().__init__()
        mp.set_start_method("spawn")
        dist.init_process_group(
            backend="nccl", rank=0, world_size=1,
            init_method="tcp://127.0.0.1:23560")
        fs_init.initialize_model_parallel(1)
        torch.cuda.set_device(0)
        torch.manual_seed(1)
        np.random.seed(1)
        self.target_dtype = {
            "bf16": torch.bfloat16,
            "fp16": torch.float16
        }['fp16']
        with default_tensor_type(dtype=self.target_dtype, device="cuda"):
            self.model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model")
        # replace OneLLM's generation and inference functions with the modified versions above
        self.model.generate = types.MethodType(MetaModel_generate, self.model)
        self.model.llma.forward_inference = types.MethodType(OneLLM_forward_inference, self.model.llma)
        print("Loading pretrained weights ...")
        checkpoint = torch.load(os.path.join(model_path, "consolidated.00-of-01.pth"), map_location='cpu')
        msg = self.model.load_state_dict(checkpoint, strict=False)
        print("load result:\n", msg)
        self.model.half().cuda().eval()
        self.image_folder = image_folder
        self.video_folder = video_folder
        self.audio_folder = audio_folder
        self.question_prompt = "Answer with the option's letter from the given choices directly."

    def evaluate_all(self, image_path, audio_path, video_path, question, options):
        # image_path: list of image paths
        # audio_path: list of audio paths
        # video_path: list of video paths
        # question: the question text
        # options: [option_A, option_B, option_C, option_D]
        option_text = "A. " + options[0] + "\n" + "B. " + options[1] + "\n" + "C. " + options[2] + "\n" + "D. " + options[3] + "\n"
        text = question + "\n" + option_text + self.question_prompt
        # replace each [img*]/[audio*]/[video*] placeholder with NUM_MODAL_TOKEN special tokens,
        # e.g. "[img1]" -> " VI VI VI ... VI "
        for index in range(len(image_path)):
            text = text.replace(f"[img{index+1}]", " " + " ".join([IMAGE_TAG] * NUM_MODAL_TOKEN) + " ")
        for index in range(len(audio_path)):
            text = text.replace(f"[audio{index+1}]", " " + " ".join([AUDIO_TAG] * NUM_MODAL_TOKEN) + " ")
        for index in range(len(video_path)):
            text = text.replace(f"[video{index+1}]", " " + " ".join([VIDEO_TAG] * NUM_MODAL_TOKEN) + " ")
        prompts = []
        conv = conv_templates["v1"].copy()
        conv.append_message(conv.roles[0], text)
        conv.append_message(conv.roles[1], None)
        prompts.append(conv.get_prompt())

        image_tensors = []
        for image_file in image_path:
            image_file = os.path.join(self.image_folder, image_file)
            image = Image.open(image_file).convert('RGB')
            image = T_resized_center_crop(image)
            image_tensors.append(image)
        if len(image_tensors) > 0:
            image_tensors = torch.stack(image_tensors).cuda().to(self.target_dtype)
        else:
            image_tensors = None

        audio_tensors = []
        for audio_file in audio_path:
            audio_file = os.path.join(self.audio_folder, audio_file)
            audio = load_audio(audio_file)
            audio_tensors.append(audio)
        if len(audio_tensors) > 0:
            audio_tensors = torch.stack(audio_tensors).cuda().to(self.target_dtype)
        else:
            audio_tensors = None

        video_tensors = []
        for video_file in video_path:
            video_file = os.path.join(self.video_folder, video_file)
            video = load_video(video_file)
            video_tensors.append(video)
        if len(video_tensors) > 0:
            video_tensors = torch.stack(video_tensors).cuda().to(self.target_dtype)
        else:
            video_tensors = None

        with torch.cuda.amp.autocast(dtype=self.target_dtype):
            responses = self.model.generate(
                prompts,
                images=image_tensors,
                audios=audio_tensors,
                videos=video_tensors,
                max_gen_len=32,
                temperature=0.0,
                top_p=0.95)
        outputs = []
        for response, prompt in zip(responses, prompts):
            response = response[len(prompt):].split('###')[0]
            response = response.strip()
            outputs.append(response)
        return outputs[0]

    def evaluate_image_audio_text(self, image_path, audio_path, question, options):
        return self.evaluate_all(image_path, audio_path, [], question, options)

    def evaluate_video_audio_text(self, video_path, audio_path, question, options):
        return self.evaluate_all([], audio_path, video_path, question, options)
if __name__ == "__main__":
    model_path = "multimodal_llama2_7B/llama2-7B_img224-patch16_llama_clip_resampler_moe_bsz512-5120_lr2e-5_warm0.05_clip2_X_v20_finetune_8gpu/epoch_0_iter_000043000"
    onellm = onellm_evaluation(model_path, "examples/imgs", "examples/videos", "examples/audios")
    question = "Please select the image below that best matches the audio: [audio1] from the first image: [img1], the second image: [img2], the third image: [img3] and the fourth image: [img4]."
    # options = ["[img1]", "[img2]", "[img3]", "[img4]"]
    options = ["the first image", "the second image", "the third image", "the fourth image"]
    audio_path = ['dog.wav']
    image_paths = ["dog.jpg", 'cat.jpg', 'bird.jpg', 'rabbit.jpg']
    res = onellm.evaluate_image_audio_text(image_paths, audio_path, question, options)
    print(res)
Hi @vittoriopipoli @qixueweigitbub, I hope this script helps you with mixed-modal input.
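For an audio-video-text query (the setting in Table 4), the same class can be driven through evaluate_video_audio_text instead of the image-audio example in __main__. Here is a minimal sketch; the checkpoint directory, the video/audio file names, and the question are placeholders I made up, not assets shipped with the repo:

# Minimal audio-video-text sketch using the onellm_evaluation class above.
# NOTE: "./OneLLM-7B", dog.mp4 and dog.wav are placeholder paths, adjust to your setup.
onellm = onellm_evaluation("./OneLLM-7B", "examples/imgs", "examples/videos", "examples/audios")
question = "Does the sound in [audio1] match the animal shown in the video [video1]?"
options = ["yes", "no", "cannot be determined", "the clip has no sound"]
video_paths = ["dog.mp4"]   # placeholder file under examples/videos
audio_paths = ["dog.wav"]   # placeholder file under examples/audios
answer = onellm.evaluate_video_audio_text(video_paths, audio_paths, question, options)
print(answer)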
Thank you very much for your answer. The single-modal inference example also uses two tags, self.start_tag[modal] and self.end_tag[modal]. Don't they need to be set for multi-modal inference as well?
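For example, would the placeholder expansion in evaluate_all need to wrap each run of repeated modal tokens, something like the sketch below? (The tag strings here are only assumptions for illustration; the real values would come from self.start_tag[modal] / self.end_tag[modal] in the single-modal example.)

# Hedged sketch only: wrap each run of repeated modal tokens with start/end tags.
START_TAG = {"image": "<image>", "audio": "<audio>", "video": "<video>"}   # assumed values
END_TAG = {"image": "</image>", "audio": "</audio>", "video": "</video>"}  # assumed values

def expand_placeholder(text, placeholder, tag, modal, num_tokens=30):
    # e.g. expand_placeholder(text, "[img1]", "VI", "image")
    # -> "... <image> VI VI ... VI </image> ..."
    run = " ".join([tag] * num_tokens)
    return text.replace(placeholder, f" {START_TAG[modal]} {run} {END_TAG[modal]} ")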
Hi @csuhan,
I recently came across your paper presented at CVPR 2024, where you introduced the OneLLM model. I found your work highly interesting and particularly relevant to my research. I am keen to conduct a detailed study on the scenarios where OneLLM processes inputs from multiple modalities, such as the audio-video-text cases described in Table 4 of your paper.
However, upon reviewing the resources available in this repository, I was unable to locate the scripts that handle experiments involving more than two modalities. I was wondering if you could kindly share the code for three-modality cases or guide me on how to set up such experiments.
I would greatly appreciate any assistance or guidance you can provide on this matter. Thank you for your time, and I look forward to your response.