Vision-Language Models (VLMs) have made significant progress on various downstream tasks through large-scale multimodal pre-training. However, they often lack strong reasoning and in-context learning abilities. Large Language Models (LLMs), by contrast, have revolutionized the NLP community with exactly these capabilities: without fine-tuning or any parameter updates, LLMs can quickly adapt to new tasks such as question answering and commonsense reasoning.
Studying in-context learning helps VLMs generalize to new knowledge in lifelong learning settings and advances their learnable capabilities. We therefore propose MIC (Multimodality In-Context Learning), a manually constructed instruction-tuning dataset that supports interleaved text-image inputs, inter-related multi-image inputs, and multimodal in-context learning inputs. By fine-tuning VLMs on MIC, we equip them with multimodal in-context learning abilities and the capacity to understand complex relationships between instructions and multiple images.
MMICL (Multi-Modal In-Context Learning) is a vision-language model built on BLIP-2/InstructBLIP. It can analyze and understand multiple images as well as follow instructions. MMICL outperforms VLMs of the same size, performs exceptionally well on complex visual reasoning datasets, and achieves state-of-the-art performance on both multimodal task leaderboards and a wide range of vision-language tasks. It also demonstrates new capabilities in video understanding and multimodal in-context learning (M-ICL).
🔥 Further details of the model and dataset will be released ASAP.
- Model: MMICL-FLANT5XXL; MMICL-Tiny
- Data: MIC_5M; MIC_Sampled
- Demo for MMICL: playground for MMICL-FLANT5XXL
- Capability of multi-image referring and reasoning
- Manually constructed in-context instruction tuning dataset
- Visual Encoder: ViT-L from CLIP / ViT-G/14 from EVA-CLIP
- Pre-trained LLM: FlanT5-XL / FlanT5-XXL / Vicuna-7B / Vicuna-13B
| | Perception (Overall) | Existence | Count | Position | Color | Posters | Celebrity | Scene | Landmark | Artwork | OCR |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| SCORE | 1376.00 | 175.00 | 143.33 | 73.33 | 171.67 | 130.95 | 145.88 | 152.75 | 135.08 | 133.00 | 115.00 |
| RANK | 1 | 4 | 2 | 5 | 1 | 7 | 3 | 6 | 8 | 3 | 2 |

| | Cognition (Overall) | Commonsense Reasoning | Numerical Calculation | Text Translation | Code Reasoning |
| --- | --- | --- | --- | --- | --- |
| SCORE | 360.36 | 117.86 | 62.50 | 107.50 | 72.50 |
| RANK | 1 | 2 | 5 | 1 | 4 |
| Method | Vision Model | Language Model | Overall | LR | AR | RR | FP-S | FP-C | CP |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| MMICL | EVA-G | FLANT5-XXL | 65.2 | 44.3 | 77.9 | 64.8 | 66.5 | 53.6 | 70.6 |
| JiuTian | EVA-G | FLANT5-XXL | 64.7 | 46.6 | 76.5 | 66.7 | 66.5 | 51.6 | 68.7 |
| Method | Vision Model | Language Model | Overall | LR | AR | RR | FP-S | FP-C | CP |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| MMICL | EVA-G | FLANT5-XXL | 67.9 | 49.2 | 77.9 | 73.0 | 66.7 | 57.2 | 77.2 |
| JiuTian | EVA-G | FLANT5-XXL | 67.1 | 46.7 | 69.7 | 71.3 | 67.3 | 58.6 | 75.8 |
Since the experiments reported in our paper were all conducted on NVIDIA DGX-A40 servers (which might be difficult to acquire), we also provide a reimplementation of MMICL.

We conduct our experiments with Anaconda3. If you have Anaconda3 installed, create the environment for MMICL:

```bash
conda env create -f environment.yml
```
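After creating and activating the environment, it is worth a quick check that PyTorch can see the GPU, since the inference examples below assume a CUDA device:

```python
# Quick environment check: the examples below assume a CUDA device is available.
import torch

print(torch.__version__)
print(torch.cuda.is_available())  # should print True on a GPU machine
```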
We draw on multiple data sources, including VCR, VQAv2, GQA, COCO, NLVR2, OK-VQA, Flickr, ST-VQA, MSRVTT, MSRVTT-QA, TextVQA, RefCOCO, WikiArt, DiffusionDB, VSR, LLaVA-Instruct, and MiniImagenet. We transform them into a few-shot style and store them in jsonl files, casting all data into a multi-instruction format with zero- to few-shot examples. Run the preprocessing script to convert the data into raw arrow files for training:
```bash
# The details of the MIC dataset will be added ASAP
python data_preprocess.py
```
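The exact conversion logic lives in data_preprocess.py; as a rough sketch of the idea, jsonl records can be turned into arrow files with the `datasets` library (the file names and schema here are hypothetical):

```python
# Minimal sketch: convert few-shot style jsonl records into raw arrow files.
# "mic_train.jsonl" and "mic_train_arrow" are hypothetical paths; see
# data_preprocess.py for the actual schema and output layout.
from datasets import load_dataset

raw = load_dataset("json", data_files={"train": "mic_train.jsonl"})
raw["train"].save_to_disk("mic_train_arrow")
```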
1. Interleaved Image-Text Data

```
Input: Image 0 is <image0> {image 0}
...
Image j is <imagej> {image j}
{question}
MMICL: {answer}
```

2. In-Context Demonstration Data

```
Input: Image 0 is <image0> {image 0}.
{question}
{answer}
...
Image j is <imagej> {image j}.
{question}
MMICL: {answer}
```
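To make the template concrete, here is a small, hypothetical helper that renders the in-context demonstration format above as a single prompt string (`replace_token` is the run of image-placeholder characters defined in the inference code below):

```python
# Hypothetical helper that renders the in-context demonstration template.
# `demos` is a list of (question, answer) pairs, one per demonstration image;
# the final query image gets only a question, leaving the answer to the model.
def build_icl_prompt(demos, query_question, replace_token):
    parts = []
    for i, (question, answer) in enumerate(demos):
        parts.append(f"Image {i} is <image{i}>{replace_token}.\n{question}\n{answer}")
    j = len(demos)
    parts.append(f"Image {j} is <image{j}>{replace_token}.\n{query_question}")
    return "\n".join(parts)
```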
*(Figures: interleaved image-text example, relevant image example, in-context example)*
Run the training scripts in `run_script`:

```bash
# For training the FLANT5-xxl version of MMICL
bash run_script/flickr/deep_speed_instructblip_t5xxl.sh
```
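Before launching a run, it can help to sanity-check that the preprocessed arrow data loads correctly; a minimal check, assuming the hypothetical output path from the sketch above:

```python
# Sanity check: make sure the preprocessed arrow dataset loads before training.
from datasets import load_from_disk

ds = load_from_disk("mic_train_arrow")  # hypothetical path from data_preprocess.py
print(ds)            # schema and number of rows
print(ds[0].keys())  # fields of one training instance
```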
```python
# For T5-based models
import torch
from PIL import Image
from model.instructblip import (
    InstructBlipConfig,
    InstructBlipForConditionalGeneration,
    InstructBlipProcessor,
)

model_type = "instructblip"
model_ckpt = "BleachNick/MMICL-Instructblip-T5-xxl"
processor_ckpt = "Salesforce/instructblip-flan-t5-xxl"

config = InstructBlipConfig.from_pretrained(model_ckpt)

if 'instructblip' in model_type:
    model = InstructBlipForConditionalGeneration.from_pretrained(
        model_ckpt,
        config=config).to('cuda:0', dtype=torch.bfloat16)

# Register the image placeholder and the <imagei> reference tokens as special
# tokens so the tokenizer never splits them.
image_placeholder = "图"
sp = [image_placeholder] + [f"<image{i}>" for i in range(20)]
processor = InstructBlipProcessor.from_pretrained(processor_ckpt)
sp = sp + processor.tokenizer.additional_special_tokens[len(sp):]
processor.tokenizer.add_special_tokens({'additional_special_tokens': sp})

# Resize the Q-Former embeddings if the special tokens grew its vocabulary.
if model.qformer.embeddings.word_embeddings.weight.shape[0] != len(processor.qformer_tokenizer):
    model.qformer.resize_token_embeddings(len(processor.qformer_tokenizer))

# Each image is rendered as 32 copies of the placeholder, one per visual token.
replace_token = "".join(32 * [image_placeholder])

image = Image.open("images/cal_num1.png")
image1 = Image.open("images/cal_num2.png")
image2 = Image.open("images/cal_num3.png")
images = [image, image1, image2]

prompt = (f'Use the image 0: <image0>{replace_token}, image 1: <image1>{replace_token} '
          f'and image 2: <image2>{replace_token} as a visual aid to help you calculate '
          f'the equation accurately. image 0 is 2+1=3.\nimage 1 is 5+6=11.\nimage 2 is')

inputs = processor(images=images, text=prompt, return_tensors="pt")

inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
# img_mask marks which image slots are real (1) rather than padding (0).
inputs['img_mask'] = torch.tensor([[1 for i in range(len(images))]])
# Add a batch dimension: (batch, num_images, channels, height, width).
inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)

inputs = inputs.to('cuda:0')
outputs = model.generate(
    pixel_values=inputs['pixel_values'],
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    img_mask=inputs['img_mask'],
    do_sample=False,
    max_length=50,
    min_length=1,
    set_min_padding_size=False,
)
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
print(generated_text)
```
Output:
```
3x6=18
```
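Every example repeats the same input-preparation steps; as a convenience, they can be folded into a small helper like the sketch below (this wrapper is ours, not part of the MMICL repository):

```python
# Convenience wrapper for the repeated input-preparation steps above.
# Not part of the MMICL repository; assumes a loaded `processor` as above.
import torch

def prepare_inputs(processor, images, prompt, device='cuda:0'):
    inputs = processor(images=images, text=prompt, return_tensors="pt")
    # Cast to bfloat16 and add the batch dimension expected by MMICL.
    inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16).unsqueeze(0)
    # Mark every image slot as real (no padding images in this batch).
    inputs['img_mask'] = torch.tensor([[1] * len(images)])
    return inputs.to(device)
```

The examples below keep the explicit steps so they can be read standalone.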
image = Image.open ("images/chinchilla.png")
image1 = Image.open ("images/shiba.png")
image2 = Image.open ("images/flamingo.png")
images = [image,image1,image2]
images = [image,image1,image2]
prompt = [f'image 0 is <image0>{replace_token},image 1 is <image1>{replace_token},image 2 is <image2>{replace_token}. Question: <image0> is a chinchilla. They are mainly found in Chile.\n Question: <image1> is a shiba. They are very popular in Japan.\nQuestion: image 2 is']
prompt = " ".join(prompt)
inputs = processor(images=images, text=prompt, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
inputs['img_mask'] = torch.tensor([[1 for i in range(len(images))]])
inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
inputs = inputs.to('cuda:0')
outputs = model.generate(
pixel_values = inputs['pixel_values'],
input_ids = inputs['input_ids'],
attention_mask = inputs['attention_mask'],
img_mask = inputs['img_mask'],
do_sample=False,
max_length=50,
min_length=1,
set_min_padding_size =False,
)
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
print(generated_text)
Output:
```
a flamingo. They are native to South America and are known for their bright red plumage and distinctive call.
```
image = Image.open ("images/flamingo_photo.png")
image1 = Image.open ("images/flamingo_cartoon.png")
image2 = Image.open ("images/flamingo_3d.png")
images = [image,image1,image2]
prompt = [f'Use the image 0: <image0>{replace_token}, image 1: <image1>{replace_token} and image 2: <image2>{replace_token} as a visual aids to help you answer the question. Question: Give the reason why image 0, image 1 and image 2 are different? Answer:']
prompt = " ".join(prompt)
inputs = processor(images=images, text=prompt, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
inputs['img_mask'] = torch.tensor([[1 for i in range(len(images))]])
inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
inputs = inputs.to('cuda:0')
outputs = model.generate(
pixel_values = inputs['pixel_values'],
input_ids = inputs['input_ids'],
attention_mask = inputs['attention_mask'],
img_mask = inputs['img_mask'],
do_sample=False,
max_length=80,
min_length=50,
num_beams=8,
set_min_padding_size =False,
)
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
print(generated_text)
Output:
```
image 0 is a photo of a flamingo standing in the water, image 1 is a cartoon drawing of a flamingo and image 2 is a low polygon count 3d model animation
```
📑 If you find our projects helpful to your research, please consider citing:
```bibtex
@article{zhao2023mmicl,
  title={MMICL: Empowering Vision-language Model with Multi-Modal In-Context Learning},
  author={Zhao, Haozhe and Cai, Zefan and Si, Shuzheng and Ma, Xiaojian and An, Kaikai and Chen, Liang and Liu, Zixuan and Wang, Sheng and Han, Wenjuan and Chang, Baobao},
  journal={arXiv preprint arXiv:2309.07915},
  year={2023}
}
```