When loading the model in 4-bit, I am facing this error: `RuntimeError: expected scalar type Float but found Half`.
Reproduce the code:
from io import BytesIO

import requests
from PIL import Image

from llava.model.builder import load_pretrained_model
def load_image(image_file):
    """Load an image from a local path or an http(s) URL and return it as RGB.

    Args:
        image_file: Filesystem path or ``http://``/``https://`` URL.

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.
    """
    # Match the full scheme ("http://"/"https://") so a local file whose
    # name merely starts with "http" is not mistaken for a URL.
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        # Fail fast on HTTP errors instead of letting PIL choke on an
        # error page's bytes.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
import torch

# Model identifiers on the Hugging Face Hub.
model_path = "liuhaotian/llava-v1.5-7b"
model_name = "llava-v1.5-7b"

# Load the model quantized to 4 bits. The quantized layers compute in
# float16, so every floating-point tensor fed to the model must be
# float16 as well.
tokenizer, model, image_processor, _ = load_pretrained_model(
    model_path,
    model_name=model_name,
    model_base=None,
    load_4bit=True,
    device_map="auto",
    device="cuda",
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load and preprocess the image.
image_path = "./gainer_schlo.JPG"
image_data = load_image(image_path)
# BUG FIX: the image processor returns float32 pixel values, but the
# 4-bit model runs in float16 — passing float32 raises
# "RuntimeError: expected scalar type Float but found Half".
# Cast to float16 while moving to the target device (replaces the old
# redundant .cuda() followed by .to(device)).
image_tensor = image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].to(device, dtype=torch.float16)

# Prepare the text input.
# NOTE(review): LLaVA normally expects the image placeholder token
# (e.g. "<image>\n") and a conversation template in the prompt —
# confirm against llava.conversation if the generated text ignores
# the image.
prompt = "Describe what you see in the image."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Generate output.
generated_ids = model.generate(
    input_ids=input_ids,
    images=image_tensor,
    max_new_tokens=256,      # max number of new tokens to generate
    num_return_sequences=1,  # number of sequences to return
)

# Convert generated tokens to text.
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(generated_text)
Traceback:
RuntimeError Traceback (most recent call last)
[/test.ipynb](https://file+.vscode-resource.vscode-cdn.net/test.ipynb) Cell 42 line 3
[31](vscode-notebook-cell:/test.ipynb#X63sZmlsZQ%3D%3D?line=30) input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
[33](vscode-notebook-cell:/test.ipynb#X63sZmlsZQ%3D%3D?line=32) # Generate output
---> [34](vscode-notebook-cell:/test.ipynb#X63sZmlsZQ%3D%3D?line=33) generated_ids = model.generate(
[35](vscode-notebook-cell:/test.ipynb#X63sZmlsZQ%3D%3D?line=34) input_ids=input_ids,
[36](vscode-notebook-cell:/test.ipynb#X63sZmlsZQ%3D%3D?line=35) images=image_tensor,
[37](vscode-notebook-cell:/test.ipynb#X63sZmlsZQ%3D%3D?line=36) max_new_tokens=256, # specify the max number of new tokens
[38](vscode-notebook-cell:/test.ipynb#X63sZmlsZQ%3D%3D?line=37) num_return_sequences=1 # specify the number of sequences to return
[39](vscode-notebook-cell:/test.ipynb#X63sZmlsZQ%3D%3D?line=38) )
[41](vscode-notebook-cell:/test.ipynb#X63sZmlsZQ%3D%3D?line=40) # Convert generated tokens to text
[42](vscode-notebook-cell:/test.ipynb#X63sZmlsZQ%3D%3D?line=41) generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
File [~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/utils/_contextlib.py:115](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/utils/_contextlib.py:115), in context_decorator.<locals>.decorate_context(*args, **kwargs)
[112](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/utils/_contextlib.py:112) @functools.wraps(func)
[113](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/utils/_contextlib.py:113) def decorate_context(*args, **kwargs):
[114](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/utils/_contextlib.py:114) with ctx_factory():
--> [115](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/utils/_contextlib.py:115) return func(*args, **kwargs)
File [~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1538](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1538), in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
[1532](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1532) raise ValueError(
[1533](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1533) "num_return_sequences has to be 1 when doing greedy search, "
[1534](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1534) f"but is {generation_config.num_return_sequences}."
[1535](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1535) )
[1537](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1537) # 11. run greedy search
-> [1538](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1538) return self.greedy_search(
[1539](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1539) input_ids,
[1540](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1540) logits_processor=logits_processor,
[1541](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1541) stopping_criteria=stopping_criteria,
[1542](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1542) pad_token_id=generation_config.pad_token_id,
[1543](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1543) eos_token_id=generation_config.eos_token_id,
[1544](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1544) output_scores=generation_config.output_scores,
[1545](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1545) return_dict_in_generate=generation_config.return_dict_in_generate,
[1546](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1546) synced_gpus=synced_gpus,
[1547](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1547) streamer=streamer,
[1548](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1548) **model_kwargs,
[1549](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1549) )
[1551](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1551) elif is_contrastive_search_gen_mode:
[1552](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:1552) if generation_config.num_return_sequences > 1:
File [~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:2362](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:2362), in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
[2359](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:2359) model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
[2361](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:2361) # forward pass to get next token
-> [2362](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:2362) outputs = self(
[2363](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:2363) **model_inputs,
[2364](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:2364) return_dict=True,
[2365](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:2365) output_attentions=output_attentions,
[2366](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:2366) output_hidden_states=output_hidden_states,
[2367](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:2367) )
[2369](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:2369) if synced_gpus and this_peer_finished:
[2370](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/generation/utils.py:2370) continue # don't waste resources running the code we don't need
File [~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1501](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1501), in Module._call_impl(self, *args, **kwargs)
[1496](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1496) # If we don't have any hooks, we want to skip the rest of the logic in
[1497](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1497) # this function, and just call forward.
[1498](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1498) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
[1499](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1499) or _global_backward_pre_hooks or _global_backward_hooks
[1500](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1500) or _global_forward_hooks or _global_forward_pre_hooks):
-> [1501](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1501) return forward_call(*args, **kwargs)
[1502](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1502) # Do not call functions when jit is used
[1503](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1503) full_backward_hooks, non_full_backward_hooks = [], []
File [~/code/LLaVA/venv_llava/lib/python3.10/site-packages/accelerate/hooks.py:165](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/accelerate/hooks.py:165), in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
[163](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/accelerate/hooks.py:163) output = old_forward(*args, **kwargs)
[164](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/accelerate/hooks.py:164) else:
--> [165](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/accelerate/hooks.py:165) output = old_forward(*args, **kwargs)
[166](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/accelerate/hooks.py:166) return module._hf_hook.post_forward(module, output)
File [~/code/LLaVA/llava/model/language_model/llava_llama.py:88](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:88), in LlavaLlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, images, return_dict)
[71](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:71) if inputs_embeds is None:
[72](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:72) (
[73](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:73) input_ids,
[74](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:74) position_ids,
(...)
[85](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:85) images
[86](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:86) )
---> [88](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:88) return super().forward(
[89](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:89) input_ids=input_ids,
[90](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:90) attention_mask=attention_mask,
[91](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:91) position_ids=position_ids,
[92](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:92) past_key_values=past_key_values,
[93](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:93) inputs_embeds=inputs_embeds,
[94](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:94) labels=labels,
[95](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:95) use_cache=use_cache,
[96](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:96) output_attentions=output_attentions,
[97](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:97) output_hidden_states=output_hidden_states,
[98](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:98) return_dict=return_dict
[99](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/llava/model/language_model/llava_llama.py:99) )
File [~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:824](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:824), in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
[822](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:822) logits = torch.cat(logits, dim=-1)
[823](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:823) else:
--> [824](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:824) logits = self.lm_head(hidden_states)
[825](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:825) logits = logits.float()
[827](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:827) loss = None
File [~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1501](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1501), in Module._call_impl(self, *args, **kwargs)
[1496](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1496) # If we don't have any hooks, we want to skip the rest of the logic in
[1497](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1497) # this function, and just call forward.
[1498](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1498) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
[1499](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1499) or _global_backward_pre_hooks or _global_backward_hooks
[1500](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1500) or _global_forward_hooks or _global_forward_pre_hooks):
-> [1501](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1501) return forward_call(*args, **kwargs)
[1502](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1502) # Do not call functions when jit is used
[1503](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/module.py:1503) full_backward_hooks, non_full_backward_hooks = [], []
File [~/code/LLaVA/venv_llava/lib/python3.10/site-packages/accelerate/hooks.py:165](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/accelerate/hooks.py:165), in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
[163](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/accelerate/hooks.py:163) output = old_forward(*args, **kwargs)
[164](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/accelerate/hooks.py:164) else:
--> [165](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/accelerate/hooks.py:165) output = old_forward(*args, **kwargs)
[166](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/accelerate/hooks.py:166) return module._hf_hook.post_forward(module, output)
File [~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/linear.py:114](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/linear.py:114), in Linear.forward(self, input)
[113](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/linear.py:113) def forward(self, input: Tensor) -> Tensor:
--> [114](https://file+.vscode-resource.vscode-cdn.net/~/code/LLaVA/venv_llava/lib/python3.10/site-packages/torch/nn/modules/linear.py:114) return F.linear(input, self.weight, self.bias)
RuntimeError: expected scalar type Float but found Half
Issue summary:
Loading the model with load_4bit=True raises `RuntimeError: expected scalar type Float but found Half` inside `model.generate` (at the `lm_head` linear layer). The reproduction code and full traceback are given above; the root cause is feeding float32 pixel values to a model whose layers compute in float16.