vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Feature]: Quantization support for LLaVA OneVision #9324

Open salvaba94 opened 1 month ago

salvaba94 commented 1 month ago

🚀 The feature, motivation and pitch

I'm working on applications that must run locally on resource-limited hardware, so quantization is essential. These applications need multimodal video-text processing, and the candidate model is LLaVA OneVision. However, vLLM does not yet support BitsAndBytes quantization for this model.

Model: LLaVA-OneVision, https://huggingface.co/llava-hf/llava-onevision-qwen2-7b-ov-hf

Challenges: AFAIK, SigLIP, the multimodal projector, and Qwen2 all need quantization support. It may also be useful to enable per-module quantization so that only the language model is quantized.

Alternatives

No response

Additional context

Trying to load a pre-quantized LLaVA OneVision model into vLLM throws:


AttributeError: Model LlavaOnevisionForConditionalGeneration does not support BitsAndBytes quantization yet.


DarkLight1337 commented 1 month ago

Which layers does it quantize specifically? Can you suggest what we need to set in bitsandbytes_stacked_params_mapping to get this to work? cc @mgoin

salvaba94 commented 1 month ago

From what I've found, in the vision tower the only parts left unquantized are the embeddings; every linear layer in the SigLIP self-attention and MLP blocks is quantized. The linear layers of the multimodal projector are fully quantized. In the Qwen2 language model, every linear layer except the head (lm_head) is quantized, again including self-attention and MLP. The quantized model is shown below:

LlavaOnevisionForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(729, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-25): 26 x SiglipEncoderLayer(
            (self_attn): SiglipFlashAttention2(
              (k_proj): Linear4bit(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear4bit(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear4bit(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear4bit(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear4bit(in_features=1152, out_features=4304, bias=True)
              (fc2): Linear4bit(in_features=4304, out_features=1152, bias=True)
            )
            (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
          )
        )
      )
      (post_layernorm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
    )
  )
  (multi_modal_projector): LlavaOnevisionMultiModalProjector(
    (linear_1): Linear4bit(in_features=1152, out_features=896, bias=True)
    (act): GELUActivation()
    (linear_2): Linear4bit(in_features=896, out_features=896, bias=True)
  )
  (language_model): Qwen2ForCausalLM(
    (model): Qwen2Model(
      (embed_tokens): Embedding(152000, 896)
      (layers): ModuleList(
        (0-23): 24 x Qwen2DecoderLayer(
          (self_attn): Qwen2FlashAttention2(
            (q_proj): Linear4bit(in_features=896, out_features=896, bias=True)
            (k_proj): Linear4bit(in_features=896, out_features=128, bias=True)
            (v_proj): Linear4bit(in_features=896, out_features=128, bias=True)
            (o_proj): Linear4bit(in_features=896, out_features=896, bias=False)
            (rotary_emb): Qwen2RotaryEmbedding()
          )
          (mlp): Qwen2MLP(
            (gate_proj): Linear4bit(in_features=896, out_features=4864, bias=False)
            (up_proj): Linear4bit(in_features=896, out_features=4864, bias=False)
            (down_proj): Linear4bit(in_features=4864, out_features=896, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
          (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        )
      )
      (norm): Qwen2RMSNorm((896,), eps=1e-06)
      (rotary_emb): Qwen2RotaryEmbedding()
    )
    (lm_head): Linear(in_features=896, out_features=152000, bias=False)
  )
)
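The shapes above also show why extra bookkeeping is needed when loading this checkpoint. Assuming vLLM's Qwen2 implementation fuses q_proj/k_proj/v_proj into a single qkv_proj and gate_proj/up_proj into gate_up_proj (as its Llama-family models do), each separately quantized HF weight has to land at a specific row range of the fused parameter. A minimal sketch of that arithmetic, using only the out_features printed above; the helper `shard_offsets` is hypothetical, and it describes the logical (unquantized, single-GPU) layout, ignoring packed 4-bit storage and tensor-parallel splits:

```python
def shard_offsets(sizes):
    """Return (start, end) row ranges for shards concatenated along dim 0."""
    offsets = {}
    start = 0
    for name, size in sizes:
        offsets[name] = (start, start + size)
        start += size
    return offsets


# out_features of the Qwen2 projections, taken from the printout above.
qkv = shard_offsets([("q_proj", 896), ("k_proj", 128), ("v_proj", 128)])
gate_up = shard_offsets([("gate_proj", 4864), ("up_proj", 4864)])

print(qkv)      # {'q_proj': (0, 896), 'k_proj': (896, 1024), 'v_proj': (1024, 1152)}
print(gate_up)  # {'gate_proj': (0, 4864), 'up_proj': (4864, 9728)}
```

The shard index in a stacked-params mapping is essentially the position in this concatenation order.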

By tweaking some arguments of BitsAndBytesConfig, quantization of the vision encoder (vision_tower in the model above) can also be skipped.
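In transformers, the modules to leave unquantized are passed through BitsAndBytesConfig's `llm_int8_skip_modules` argument. The sketch below only illustrates how such a skip list selects layers; the prefix-matching rule and the helper `quantized_linears` are simplifying assumptions, not the exact logic inside transformers or bitsandbytes. Module names follow the printout above, with one representative layer per block:

```python
# Representative Linear layers from the printout (one encoder/decoder layer each).
LINEAR_LAYERS = [
    "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj",
    "vision_tower.vision_model.encoder.layers.0.mlp.fc1",
    "multi_modal_projector.linear_1",
    "multi_modal_projector.linear_2",
    "language_model.model.layers.0.self_attn.q_proj",
    "language_model.model.layers.0.mlp.gate_proj",
    "language_model.lm_head",
]


def quantized_linears(names, skip_modules):
    """Return the Linear layers that would be quantized.

    A layer is skipped if its dotted name equals a skip entry or starts
    with that entry followed by a dot (simplified matching rule).
    """
    def skipped(name):
        return any(name == s or name.startswith(s + ".") for s in skip_modules)

    return [n for n in names if not skipped(n)]


# Keep lm_head in full precision (as in the printout) and also skip the vision tower.
lang_only = quantized_linears(LINEAR_LAYERS, ["vision_tower", "language_model.lm_head"])
for name in lang_only:
    print(name)
```

With this skip list, only the projector and the Qwen2 decoder layers remain candidates for 4-bit quantization.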

I'm not very familiar with what bitsandbytes_stacked_params_mapping is meant to represent. With some initial guidance, I think I could get it working. A first step would be to get the Qwen2 and SigLIP models working separately.
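For reference, a sketch of what such a mapping could look like. It assumes the dict format used by vLLM's Llama implementation (HF checkpoint shard name mapped to a fused vLLM parameter name and shard index), and it assumes vLLM's Qwen2 and SigLIP modules fuse q/k/v into `qkv_proj` and Qwen2 fuses gate/up into `gate_up_proj`. The helper `resolve` is hypothetical and only illustrates how a checkpoint weight name would be redirected; under these assumptions, unfused layers such as o_proj, down_proj, out_proj, fc1 and fc2 need no entry.

```python
# Checkpoint shard name -> (fused vLLM parameter name, shard index).
# Format modeled on vLLM's Llama `bitsandbytes_stacked_params_mapping`;
# the entries assume Qwen2 and SigLIP fuse their projections the same way.
bitsandbytes_stacked_params_mapping = {
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}


def resolve(weight_name, mapping):
    """Hypothetical helper: map an HF weight name to (param name, shard index).

    Returns shard index None for weights that are loaded as-is.
    """
    for shard_name, (fused_name, index) in mapping.items():
        # Match whole dotted components so "up_proj" never hits "gate_up_proj".
        if f".{shard_name}." in weight_name:
            return weight_name.replace(f".{shard_name}.", f".{fused_name}.", 1), index
    return weight_name, None


print(resolve("language_model.model.layers.0.self_attn.k_proj.weight",
              bitsandbytes_stacked_params_mapping))
# ('language_model.model.layers.0.self_attn.qkv_proj.weight', 1)
print(resolve("vision_tower.vision_model.encoder.layers.0.mlp.fc1.weight",
              bitsandbytes_stacked_params_mapping))
# ('vision_tower.vision_model.encoder.layers.0.mlp.fc1.weight', None)
```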