TRI-ML / prismatic-vlms

A flexible and efficient codebase for training visually-conditioned language models (VLMs)
MIT License

Quantization support #12

Open: show981111 opened this issue 3 months ago

show981111 commented 3 months ago

Hi, thank you for the awesome work. I was wondering whether there is a quantized version of Prismatic, or whether I could at least quantize the LLM backbone. I saw that inference loads the weights using `load_state_dict`, so I am not sure how to approach quantization. Any insight would be helpful. Thanks!
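Loading weights with `load_state_dict` does not by itself rule out quantization. One option is to load the full-precision weights first and quantize the module afterwards. The sketch below does this with PyTorch's built-in dynamic int8 quantization (`torch.ao.quantization.quantize_dynamic`). This is a different technique from the bitsandbytes approach discussed later in the thread, and it targets CPU inference only. `ToyBackbone` is a hypothetical stand-in, not Prismatic's actual backbone class.

```python
import torch
import torch.nn as nn


class ToyBackbone(nn.Module):
    """Hypothetical stand-in for an LLM backbone (not Prismatic's real class)."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj_in = nn.Linear(dim, dim * 4)
        self.act = nn.GELU()
        self.proj_out = nn.Linear(dim * 4, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj_out(self.act(self.proj_in(x)))


# 1) Load full-precision weights as usual (here taken from a freshly built model).
full_precision = ToyBackbone().eval()
state_dict = full_precision.state_dict()

backbone = ToyBackbone()
backbone.load_state_dict(state_dict)
backbone.eval()

# 2) Quantize *after* loading: returns a copy in which every nn.Linear is replaced
#    by a dynamically quantized int8 Linear (CPU only; `backbone` is left unchanged).
quantized = torch.ao.quantization.quantize_dynamic(
    backbone, {nn.Linear}, dtype=torch.qint8
)

# 3) Compare outputs of the full-precision and quantized models on the same input.
x = torch.randn(2, 16)
with torch.no_grad():
    ref = full_precision(x)
    out = quantized(x)

print(type(quantized.proj_in))
print("max abs difference:", (ref - out).abs().max().item())
```

The key point is the ordering: `load_state_dict` runs on the unquantized module, and quantization happens as a separate step afterwards, so the checkpoint format does not need to change.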

siddk commented 2 months ago

This is a good question; I would love to support this, but I don't have much experience loading LLMs in 4-bit/8-bit precision. If you can link me to some code for loading, e.g., LLaMA-2 in 8-bit precision, I can see what would make sense!

djghosh13 commented 2 months ago

If I understand correctly, `LlamaForCausalLM` already supports quantization out of the box through `from_pretrained`. Something like

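```python
import torch
import transformers
from transformers import LlamaForCausalLM

# 8-bit weights via bitsandbytes (use load_in_4bit=True instead for 4-bit)
quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.bfloat16,  # dtype for the modules that are not quantized
    quantization_config=quantization_config,
)
```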

works for me to load LLaMA-2 in 8-bit (or in 4-bit if you set `load_in_4bit=True` in the `BitsAndBytesConfig` instead).

The docs for `BitsAndBytesConfig` are here: https://huggingface.co/docs/transformers/en/main_classes/quantization#transformers.BitsAndBytesConfig
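As a rough guide to what 8-bit versus 4-bit loading buys you, weight storage scales linearly with bits per parameter. The following is back-of-the-envelope arithmetic only. It assumes a nominal 7e9 parameters for LLaMA-2-7B and counts weights alone, ignoring activations, the KV cache, bitsandbytes' quantization constants, and any modules kept in higher precision. `weight_memory_gb` is a hypothetical helper.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate weight storage in GB (1 GB = 1e9 bytes), weights only."""
    return num_params * bits_per_param / 8 / 1e9


NUM_PARAMS = 7e9  # assumption: nominal "7B"; the real parameter count is slightly lower

for label, bits in [("bf16", 16), ("int8", 8), ("4-bit", 4)]:
    print(f"{label:>5}: ~{weight_memory_gb(NUM_PARAMS, bits):.1f} GB")
```

This prints about 14.0 GB for bf16, 7.0 GB for int8 and 3.5 GB for 4-bit. Real GPU usage will be higher once runtime overheads are included.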