The success of Large Language Models (LLMs) has led researchers to explore Multimodal Large Language Models (MLLMs) for unified visual and linguistic understanding. However, the growing model size and computational complexity of MLLMs limit their use in resource-constrained environments. Small-scale MLLMs ($s$-MLLM) aim to retain the capabilities of their large-scale counterparts ($l$-MLLM) while reducing computational demands, but this typically comes at the cost of a significant decline in performance. To address this issue, we propose LLaVA-KD, a novel framework for transferring knowledge from an $l$-MLLM to an $s$-MLLM. Specifically, we introduce Multimodal Distillation (MDist) to minimize the divergence between the visual-textual output distributions of the $l$-MLLM and the $s$-MLLM, and Relation Distillation (RDist) to transfer the $l$-MLLM's ability to model correlations between visual features. We further propose a three-stage training scheme to fully exploit the potential of the $s$-MLLM: (1) Distilled Pre-Training to align visual-textual representations, (2) Supervised Fine-Tuning to equip the model with multimodal understanding, and (3) Distilled Fine-Tuning to further transfer the $l$-MLLM's capabilities. Our approach significantly improves performance without altering the small model's architecture. Extensive experiments and ablation studies validate the effectiveness of each proposed component.
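To make the two objectives concrete, below is a minimal PyTorch sketch of what MDist and RDist could look like. The function names, the temperature, the KL-divergence form, and the similarity-matrix construction are illustrative assumptions, not the exact losses from the paper.

```python
import torch
import torch.nn.functional as F

def mdist_loss(student_logits, teacher_logits, temperature=1.0):
    """MDist sketch: match the student's visual-textual output
    distribution to the teacher's via KL divergence over logits."""
    s = F.log_softmax(student_logits / temperature, dim=-1)
    t = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(s, t, reduction="batchmean") * temperature ** 2

def rdist_loss(student_visual, teacher_visual):
    """RDist sketch: match pairwise similarity matrices over visual tokens
    so the student inherits the teacher's visual-feature correlations."""
    s_rel = torch.softmax(
        F.normalize(student_visual, dim=-1) @ F.normalize(student_visual, dim=-1).transpose(-1, -2), dim=-1)
    t_rel = torch.softmax(
        F.normalize(teacher_visual, dim=-1) @ F.normalize(teacher_visual, dim=-1).transpose(-1, -2), dim=-1)
    return F.kl_div(s_rel.log(), t_rel, reduction="batchmean")

# Illustrative shapes: (batch, tokens, vocab) logits and (batch, tokens, dim) visual features.
student_logits, teacher_logits = torch.randn(2, 16, 32000), torch.randn(2, 16, 32000)
student_vis, teacher_vis = torch.randn(2, 64, 1024), torch.randn(2, 64, 1024)
loss = mdist_loss(student_logits, teacher_logits) + rdist_loss(student_vis, teacher_vis)
```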
Benchmark results compared with SoTA MLLMs. Our LLaVA-KD achieves highly competitive results compared with current small-scale MLLMs. AVG: the average over the nine benchmarks (excluding MMMU) for a comprehensive comparison. $^\dagger$: results reproduced using the official code.
```bash
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
python3 -m pip install --no-cache-dir --upgrade -r requirements.txt
python3 -m pip install numpy==1.26.2
python3 -m pip install urllib3==1.26.6
pip install ptflops
```
```bash
# Build FlashAttention from source
git clone https://github.com/Dao-AILab/flash-attention.git
cd ./flash-attention
python3 -m pip install wheel==0.41.3
python3 setup.py install
cd ..

# Build bitsandbytes from source
git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git
cd ./bitsandbytes
pip install -e .
```
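After the steps above, the following quick check (an illustrative snippet, not part of the official repo) confirms that PyTorch sees the GPU and that flash-attn and bitsandbytes import correctly:

```python
import torch
import flash_attn
import bitsandbytes as bnb

# Print versions and CUDA availability to verify the environment.
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("flash-attn:", flash_attn.__version__)
print("bitsandbytes:", bnb.__version__)
```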
| Model | Vision Encoder | LLM | CKPTs |
|---|---|---|---|
| LLaVA-KD-1B | siglip-so400m-patch14-384 | Qwen/Qwen1.5-0.5B | LLaVA-KD-Base-siglip-Qwen1.5-0.5B |
| LLaVA-KD-2B | siglip-so400m-patch14-384 | Qwen/Qwen1.5-1.8B | LLaVA-KD-Base-siglip-Qwen1.5-1.8B |
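If the checkpoints above are hosted on the Hugging Face Hub, they can also be fetched programmatically into a local directory such as `./pretrained_ckpt` (used by the inference example below). This is a minimal sketch; `<org>` is a placeholder, so substitute the actual repository ID from the CKPTs column:

```python
from huggingface_hub import snapshot_download

# Placeholder repo ID -- replace <org> with the actual organization hosting the checkpoint.
snapshot_download(
    repo_id="<org>/LLaVA-KD-Base-siglip-Qwen1.5-0.5B",
    local_dir="./pretrained_ckpt/LLaVA-KD-Base-siglip-Qwen1.5-0.5B",
)
```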
Please evaluate the model according to Evaluation.md.
Download the pre-trained vision encoder, LLM, and LLaVA-KD weights to `./pretrained_ckpt`, then run:
```bash
python quick_inference.py --model_path ./pretrained_ckpt/LLaVAKD_Model_Path --image_file ./image_test/img_test_1.jpg --query "What is that orange thing behind the girl?"
```
If you find this code useful, don't forget to star the repo and cite the paper.
```bibtex
@article{cai2024llava,
  title={LLaVA-KD: A Framework of Distilling Multimodal Large Language Models},
  author={Cai, Yuxuan and Zhang, Jiangning and He, Haoyang and He, Xinwei and Tong, Ao and Gan, Zhenye and Wang, Chengjie and Bai, Xiang},
  journal={arXiv preprint arXiv:2410.16236},
  year={2024}
}
```
We thank the authors of the great works TinyLLaVA and LLaVA, which provided valuable assistance for our research.