haotian-liu / LLaVA

[NeurIPS'23 Oral] Visual Instruction Tuning (LLaVA) built towards GPT-4V level capabilities and beyond.
https://llava.hliu.cc
Apache License 2.0

[Feature request] Support for load_in_8bit #132

Open · Samin100 opened this issue 1 year ago

Samin100 commented 1 year ago

feature

Being able to run LLaVA in 8-bit mode would allow better support for inference on consumer GPUs, since it lowers memory requirements. Passing load_in_8bit=True to from_pretrained in eval/run_llava.py doesn't work; I'm testing with the 7B v1.1 model. Do you know what might need to be changed in llava.py to support 8-bit inference?

import torch
from llava.model.llava import LlavaLlamaForCausalLM

# 7B v1.1 model
model = LlavaLlamaForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    use_cache=True,
    device_map='auto',
    load_in_8bit=True
).cuda()

Traceback:

  File "/home/ubuntu/llava/LLaVA/llava/eval/run_llava.py", line 191, in <module>
    eval_model(args)
  File "/home/ubuntu/llava/LLaVA/llava/eval/run_llava.py", line 146, in eval_model
    output_ids = model.generate(
  File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/generation/utils.py", line 1462, in generate
    return self.sample(
  File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/transformers/generation/utils.py", line 2478, in sample
    outputs = self(
  File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/ubuntu/llava/LLaVA/llava/model/llava.py", line 222, in forward
    outputs = self.model(
  File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/ubuntu/llava/LLaVA/llava/model/llava.py", line 133, in forward
    image_features = self.mm_projector(image_features)
  File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 320, in forward
    out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
  File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 500, in matmul
    return MatMul8bitLt.apply(A, B, out, bias, state)
  File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/ubuntu/miniconda3/envs/llava/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 322, in forward
    A = A.view(-1, A.shape[-1]).contiguous()
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
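
The failure is inside bitsandbytes' 8-bit Linear layer: load_in_8bit=True quantizes every nn.Linear, including the multimodal projector (mm_projector), and Linear8bitLt calls .view() on its input, which requires a contiguous tensor; the visual features arriving at the projector are not contiguous. A minimal sketch of a workaround, assuming transformers' BitsAndBytesConfig is available and taking the mm_projector module name from the traceback above, is to keep the projector out of int8 quantization:

import torch
from transformers import BitsAndBytesConfig
from llava.model.llava import LlavaLlamaForCausalLM

# Sketch: quantize the language model to 8-bit but leave mm_projector in fp16,
# so bitsandbytes never sees the non-contiguous visual-feature tensor.
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=['mm_projector'],  # module name taken from the traceback above
)
model = LlavaLlamaForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map='auto',
    quantization_config=quant_config,
)  # no .cuda() call: device_map='auto' already places the quantized weights

Alternatively, calling image_features.contiguous() before the mm_projector call in llava.py sidesteps the .view() restriction, as the error message itself suggests.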

haotian-liu commented 1 year ago

Hi, please check out the latest codebase, which supports both 4-bit and 8-bit inference.

https://github.com/haotian-liu/LLaVA#launch-a-model-worker-4-bit-8-bit-inference-quantized
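
For reference, the linked README section starts a quantized model worker with a --load-8bit (or --load-4bit) flag; a representative invocation, with the controller/worker addresses and model path as illustrative placeholders, looks like:

python -m llava.serve.model_worker --host 0.0.0.0 \
    --controller http://localhost:10000 --port 40000 \
    --worker http://localhost:40000 \
    --model-path liuhaotian/llava-v1.5-13b \
    --load-8bit

The same flags are also accepted by the single-process CLI (python -m llava.serve.cli), which avoids running a controller at all.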