haotian-liu / LLaVA

[NeurIPS'23 Oral] Visual Instruction Tuning (LLaVA) built towards GPT-4V level capabilities and beyond.
https://llava.hliu.cc
Apache License 2.0
19.37k stars 2.13k forks

[Usage] llava-7b 4bit quantization inference problem #829

Open Leon-Sander opened 10 months ago

Leon-Sander commented 10 months ago

Issue:

When loading the model in 4-bit, I am facing this error: RuntimeError: expected scalar type Float but found Half.

Code to reproduce:

import requests
import torch
from llava.model.builder import load_pretrained_model
from io import BytesIO
from PIL import Image
def load_image(image_file):
    if image_file.startswith('http') or image_file.startswith('https'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image

# Set the model path and name
model_path = "liuhaotian/llava-v1.5-7b"
model_name = "llava-v1.5-7b"

# Load the model
tokenizer, model, image_processor, _ = load_pretrained_model(model_path, model_name=model_name, model_base=None, load_4bit=True, device_map="auto", device="cuda")

# Load your image
image_path = "./gainer_schlo.JPG"
image_data = load_image(image_path)
image_tensor = image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].cuda()

# Prepare the prompt
prompt = "Describe what you see in the image."
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model = model.to(device)  
image_tensor = image_tensor.to(device)

# Tokenize the prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Generate output
generated_ids = model.generate(
    input_ids=input_ids,
    images=image_tensor,
    max_new_tokens=256,  # specify the max number of new tokens
    num_return_sequences=1  # specify the number of sequences to return
)

# Convert generated tokens to text
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(generated_text)

Traceback:

RuntimeError                              Traceback (most recent call last)
/test.ipynb Cell 42 line 3
     31 input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
     33 # Generate output
---> 34 generated_ids = model.generate(
     35     input_ids=input_ids,
     36     images=image_tensor,
     37     max_new_tokens=256,  # specify the max number of new tokens
     38     num_return_sequences=1  # specify the number of sequences to return
     39 )
     41 # Convert generated tokens to text
     42 generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

File ~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1538, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
   1532         raise ValueError(
   1533             "num_return_sequences has to be 1 when doing greedy search, "
   1534             f"but is {generation_config.num_return_sequences}."
   1535         )
   1537     # 11. run greedy search
-> 1538     return self.greedy_search(
   1539         input_ids,
   1540         logits_processor=logits_processor,
   1541         stopping_criteria=stopping_criteria,
   1542         pad_token_id=generation_config.pad_token_id,
   1543         eos_token_id=generation_config.eos_token_id,
   1544         output_scores=generation_config.output_scores,
   1545         return_dict_in_generate=generation_config.return_dict_in_generate,
   1546         synced_gpus=synced_gpus,
   1547         streamer=streamer,
   1548         **model_kwargs,
   1549     )
   1551 elif is_contrastive_search_gen_mode:
   1552     if generation_config.num_return_sequences > 1:

File ~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:2362, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2359 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2361 # forward pass to get next token
-> 2362 outputs = self(
   2363     **model_inputs,
   2364     return_dict=True,
   2365     output_attentions=output_attentions,
   2366     output_hidden_states=output_hidden_states,
   2367 )
   2369 if synced_gpus and this_peer_finished:
   2370     continue  # don't waste resources running the code we don't need

File ~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/code/LLaVA/venv_llava/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/code/LLaVA/llava/model/language_model/llava_llama.py:88, in LlavaLlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, images, return_dict)
     71 if inputs_embeds is None:
     72     (
     73         input_ids,
     74         position_ids,
    (...)
     85         images
     86     )
---> 88 return super().forward(
     89     input_ids=input_ids,
     90     attention_mask=attention_mask,
     91     position_ids=position_ids,
     92     past_key_values=past_key_values,
     93     inputs_embeds=inputs_embeds,
     94     labels=labels,
     95     use_cache=use_cache,
     96     output_attentions=output_attentions,
     97     output_hidden_states=output_hidden_states,
     98     return_dict=return_dict
     99 )

File ~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:824, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    822     logits = torch.cat(logits, dim=-1)
    823 else:
--> 824     logits = self.lm_head(hidden_states)
    825 logits = logits.float()
    827 loss = None

File ~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/code/LLaVA/venv_llava/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: expected scalar type Float but found Half
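
The final frame shows the dtype mismatch surfacing in lm_head. One plausible cause (an assumption, not confirmed in this thread): with load_4bit=True parts of the model compute in fp16, while image_processor returns an fp32 tensor. A minimal sketch of a fix is to cast the image tensor to half precision before generate(), similar to what the repo's own fp16 inference scripts do:

# Hypothetical fix (assumption: the fp32 image tensor meets fp16 weights
# inside the quantized model). Cast the pixel values to the model's
# fp16 compute dtype before passing them to generate().
image_tensor = image_processor.preprocess(image_data, return_tensors='pt')['pixel_values']
image_tensor = image_tensor.to(device="cuda", dtype=torch.float16)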
guzixian commented 8 months ago

with torch.autocast("cuda"):
    trainer.train()
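
Adapting that suggestion to the inference snippet above (a sketch under the assumption that autocast resolves the Float/Half mismatch; the reply's trainer.train() targets training, and generate() is the analogous call here):

# Run generation under autocast so mixed fp32/fp16 operands are cast
# to a common dtype inside the context.
with torch.autocast("cuda", dtype=torch.float16):
    generated_ids = model.generate(
        input_ids=input_ids,
        images=image_tensor,
        max_new_tokens=256,
    )
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))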