zjysteven / lmms-finetune

A minimal codebase for finetuning large multimodal models, supporting llava-1.5/1.6, llava-interleave, llava-next-video, llava-onevision, qwen-vl, qwen2-vl, phi3-v etc.
Apache License 2.0

Pre-compute vision encoder features #18

Closed nisargshah1999 closed 2 months ago

nisargshah1999 commented 2 months ago

Hi,

Thanks for the very useful and interesting work. I was wondering if there is an easy way to precompute the vision encoder features for training (and load them later), so that the memory issue can be mitigated.

zjysteven commented 2 months ago

That's a good idea and definitely possible. I can try to support such options in the following days. Meanwhile if you'd like to implement something yourself I might be able to give some high level ideas. Let me know.

nisargshah1999 commented 2 months ago

Good to hear. Sure, I can work on that. Let me know the pointers (and anything not to miss); I'm new to HF modules.

zjysteven commented 2 months ago

After thinking about it, the best option would be for me (and our maintainers) to implement it, because there will be some model loading/unloading logic involved that is coupled with the existing logic. I'm already working on it and it should be available soon.

Hi @nisargshah1999, after digging into it I realize that supporting cached vision features might not be as straightforward as I thought. The problem is that, to support unified finetuning of multiple models, we currently rely on the relatively unified forward function/interface defined by the HF models, and that forward function of course goes through everything: the vision encoder, the projector, and the LLM. If we were to cache the vision features, we would need a custom forward function that feeds the cached features to the relevant modules and bypasses the vision encoder. Such a custom forward function would have to be written case by case, model by model, which would be quite cumbersome to implement.
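
For what it's worth, here is a rough sketch (not part of lmms-finetune, just an illustration) of what precomputing the features could look like, assuming the HF `LlavaForConditionalGeneration` layout with `vision_tower` / `multi_modal_projector` / `language_model` submodules; other model families expose different modules, which is exactly where the per-model work comes in:

```python
# A rough sketch, not part of lmms-finetune: precomputing vision features
# for an HF LLaVA-style model. The submodule names below follow the HF
# LlavaForConditionalGeneration layout; other model families differ.
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16)
processor = AutoProcessor.from_pretrained(model_id)

@torch.no_grad()
def precompute_image_features(images):
    # Run the (frozen) vision encoder + projector once; the result could be
    # cached to disk and reused across epochs.
    pixel_values = processor.image_processor(images, return_tensors="pt")["pixel_values"]
    pixel_values = pixel_values.to(model.device, dtype=model.dtype)
    outputs = model.vision_tower(pixel_values, output_hidden_states=True)
    # Which hidden layer is used and whether the CLS token is dropped is
    # model-specific; llava-1.5 takes the second-to-last layer and drops CLS.
    feats = outputs.hidden_states[-2][:, 1:]
    return model.multi_modal_projector(feats)

# At training time one would then have to bypass model.forward entirely:
# embed the text tokens, splice the cached features in at the <image>
# placeholder positions, and call model.language_model(inputs_embeds=...).
# That splicing logic differs per model, which is the cumbersome part.
```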