fpgaminer / joycaption

JoyCaption is an image captioning Visual Language Model (VLM) being built from the ground up as a free, open, and uncensored model for the community to use in training Diffusion models.
Apache License 2.0

How to run with BNB 4bit or 8bit quantization? #3

Open fireicewolf opened 1 month ago

fireicewolf commented 1 month ago

I tried to modify your example code to run this model on a low-VRAM card using a BNB (bitsandbytes) 4-bit or 8-bit quantization config.

When using a bnb 4-bit config like the one below:

import torch
from transformers import BitsAndBytesConfig

qnt_config = BitsAndBytesConfig(load_in_4bit=True,
                                bnb_4bit_quant_type="nf4",
                                bnb_4bit_compute_dtype=torch.float16,
                                bnb_4bit_use_double_quant=True)

The first time, with pixel_values = pixel_values.to(torch.bfloat16).unsqueeze(0), this error occurred:

RuntimeError: Input type (CUDABFloat16Type) and weight type (torch.cuda.HalfTensor) should be the same

Then I changed it to pixel_values = pixel_values.to(llm_dtype).unsqueeze(0) (where llm_dtype is the dtype the LLaVA model weights were loaded with), and got:

RuntimeError: self and mat2 must have the same dtype, but got Half and Byte

These errors seem to be caused by the dtype of the image input.

Any idea how to make it work?
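The first error above is PyTorch's ordinary check that a convolution's input and weights share a dtype. The weight in question is presumably the SigLIP vision tower's patch-embedding convolution: bitsandbytes only replaces Linear layers, so that convolution stays unquantized, and the error shows it was loaded in float16. Below is a minimal CPU sketch of that check, using float64/float32 as stand-ins for bfloat16/float16. The layer sizes and the helper name conv_with_input_dtype are illustrative, not from JoyCaption. Casting the image to the layer's own weight dtype, rather than hard-coding one, clears this particular error. It does not address the second ("Half and Byte") error, which comes from a quantized uint8 weight.

```python
import torch
import torch.nn as nn


def conv_with_input_dtype(input_dtype: torch.dtype, match_weight: bool) -> torch.Tensor:
    """Hypothetical helper: run a float32 conv on an image of the given dtype."""
    torch.manual_seed(0)
    # Stand-in for the vision tower's (unquantized) patch-embedding convolution.
    conv = nn.Conv2d(3, 4, kernel_size=2, stride=2)
    pixel_values = torch.randn(1, 3, 8, 8, dtype=input_dtype)
    if match_weight:
        # Cast the image to whatever dtype the layer's weights actually use.
        pixel_values = pixel_values.to(conv.weight.dtype)
    # Raises RuntimeError if the input dtype differs from the weight dtype.
    return conv(pixel_values)
```

Calling conv_with_input_dtype(torch.float64, False) raises a RuntimeError, while conv_with_input_dtype(torch.float64, True) returns a float32 tensor of shape (1, 4, 4, 4).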

fpgaminer commented 1 month ago

Keep pixel_values = pixel_values.to(torch.bfloat16).unsqueeze(0), but maybe try changing bnb_4bit_compute_dtype=torch.float16 to bnb_4bit_compute_dtype=torch.bfloat16?

fireicewolf commented 1 month ago

Setting bnb_4bit_compute_dtype=torch.bfloat16 to match pixel_values.to(torch.bfloat16).unsqueeze(0) causes this error:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/gradio/queueing.py", line 622, in process_events
    response = await route_utils.call_process_api(
  File "/opt/conda/lib/python3.10/site-packages/gradio/route_utils.py", line 323, in call_process_api
    output = await app.get_blocks().process_api(
  File "/opt/conda/lib/python3.10/site-packages/gradio/blocks.py", line 2014, in process_api
    result = await self.call_function(
  File "/opt/conda/lib/python3.10/site-packages/gradio/blocks.py", line 1567, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/opt/conda/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2177, in run_sync_in_worker_thread
    return await future
  File "/opt/conda/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 859, in run
    result = context.run(func, *args)
  File "/opt/conda/lib/python3.10/site-packages/gradio/utils.py", line 846, in wrapper
    response = f(*args, **kwargs)
  File "/root/private_data/wd-joy-caption-cli/wd_llm_caption/gui.py", line 713, in caption_single_inference
    caption_text = get_caption_fn.my_llm.get_caption(
  File "/root/private_data/wd-joy-caption-cli/wd_llm_caption/utils/inference.py", line 607, in get_caption
    self.llm.generate(input_ids=input_ids, pixel_values=pixel_values,
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 2047, in generate
    result = self._sample(
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 3007, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llava/modeling_llava.py", line 453, in forward
    image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1189, in forward
    return self.vision_model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1100, in forward
    pooler_output = self.head(last_hidden_state) if self.use_head else None
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1127, in forward
    hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1241, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 5430, in multi_head_attention_forward
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
RuntimeError: self and mat2 must have the same dtype, but got BFloat16 and Byte

If I use a bnb 8-bit config:

qnt_config = BitsAndBytesConfig(load_in_8bit=True,
                                llm_int8_enable_fp32_cpu_offload=True)

it causes this error:

/opt/conda/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:324: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/gradio/queueing.py", line 622, in process_events
    response = await route_utils.call_process_api(
  File "/opt/conda/lib/python3.10/site-packages/gradio/route_utils.py", line 323, in call_process_api
    output = await app.get_blocks().process_api(
  File "/opt/conda/lib/python3.10/site-packages/gradio/blocks.py", line 2014, in process_api
    result = await self.call_function(
  File "/opt/conda/lib/python3.10/site-packages/gradio/blocks.py", line 1567, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/opt/conda/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2177, in run_sync_in_worker_thread
    return await future
  File "/opt/conda/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 859, in run
    result = context.run(func, *args)
  File "/opt/conda/lib/python3.10/site-packages/gradio/utils.py", line 846, in wrapper
    response = f(*args, **kwargs)
  File "/root/private_data/wd-joy-caption-cli/wd_llm_caption/gui.py", line 713, in caption_single_inference
    caption_text = get_caption_fn.my_llm.get_caption(
  File "/root/private_data/wd-joy-caption-cli/wd_llm_caption/utils/inference.py", line 607, in get_caption
    self.llm.generate(input_ids=input_ids, pixel_values=pixel_values,
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 2047, in generate
    result = self._sample(
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 3007, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llava/modeling_llava.py", line 453, in forward
    image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1189, in forward
    return self.vision_model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1100, in forward
    pooler_output = self.head(last_hidden_state) if self.use_head else None
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1127, in forward
    hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1241, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 5430, in multi_head_attention_forward
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
RuntimeError: self and mat2 must have the same dtype, but got BFloat16 and Char

fpgaminer commented 1 month ago

Looks like this is a bug in transformers. I've submitted a bug report and pull request to get it fixed: https://github.com/huggingface/transformers/issues/34294

We'll have to wait for that to get fixed and a new version of transformers released to have a clean fix here.
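Until a fixed transformers release is available, one possible workaround (not tested in this thread) is to keep the vision tower out of quantization with BitsAndBytesConfig's llm_int8_skip_modules parameter. The SigLIP attention head then keeps a bfloat16 out_proj.weight, and pixel_values can stay in bfloat16. This sketch assumes the LLaVA submodule names vision_tower and multi_modal_projector; vision_tower appears in the traceback above. load_llava_4bit is a hypothetical helper. Passing llm_int8_skip_modules is assumed to replace transformers' default skip list, which is why lm_head is listed explicitly. Check this against your transformers version.

```python
import torch
from transformers import BitsAndBytesConfig, LlavaForConditionalGeneration


def load_llava_4bit(model_path: str) -> LlavaForConditionalGeneration:
    """Hypothetical helper: load a LLaVA-style checkpoint with the language
    model in 4-bit NF4 and the vision side left in bfloat16."""
    qnt_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        # Submodules matching these names are left unquantized. The names are
        # assumptions based on the usual LLaVA layout, not JoyCaption docs.
        llm_int8_skip_modules=["vision_tower", "multi_modal_projector", "lm_head"],
    )
    return LlavaForConditionalGeneration.from_pretrained(
        model_path,
        quantization_config=qnt_config,
        torch_dtype=torch.bfloat16,  # dtype for the modules that are not quantized
        device_map="auto",
    )
```

The trade-off is VRAM: the vision tower and projector stay at 16 bits. The language model, which typically holds most of the parameters, is still quantized.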

fireicewolf commented 1 month ago

Thanks for your help; let's wait for Hugging Face's response.

Tablaski commented 3 weeks ago

Interested as well. I could never get NF4 to run, and the original model is way too slow on my setup :-(

Tablaski commented 1 week ago

@fpgaminer any news? I read the reply on the transformers GitHub issue and tried their solution, but it didn't change anything.

effusiveperiscope commented 1 week ago

https://github.com/bitsandbytes-foundation/bitsandbytes/issues/963 seems to be related. I tried with the current transformers from GitHub as of today. It looks like in modeling_siglip.py, self.attention.out_proj.weight has a uint8 dtype after 4-bit quantization, and the multi-head attention calculation fails with the same error. I don't know enough about quantization to know whether this is correct behavior.
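The tracebacks above suggest why a uint8 out_proj.weight breaks. nn.MultiheadAttention.forward passes self.out_proj.weight straight into F.multi_head_attention_forward, which then calls linear(attn_output, out_proj_weight, out_proj_bias). A quantized replacement for out_proj would presumably dequantize inside its own forward(), but that forward() is never called on this path. Here is a minimal pure-PyTorch sketch of this on CPU. CountingLinear is a hypothetical stand-in for a quantized Linear subclass. The all-zero uint8 tensor only imitates a quantized weight; real bitsandbytes 4-bit weights also have a packed shape.

```python
from typing import Tuple

import torch
import torch.nn as nn


class CountingLinear(nn.Linear):
    """Hypothetical stand-in for a quantized Linear: counts forward() calls."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__(in_features, out_features)
        self.calls = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.calls += 1
        return super().forward(x)


def demo(embed_dim: int = 8, num_heads: int = 2) -> Tuple[int, bool]:
    torch.manual_seed(0)
    mha = nn.MultiheadAttention(embed_dim, num_heads)
    # Swap out_proj for a Linear subclass, as a quantization pass would.
    mha.out_proj = CountingLinear(embed_dim, embed_dim)

    x = torch.randn(3, 1, embed_dim)  # (seq_len, batch, embed_dim)
    mha(x, x, x)
    calls = mha.out_proj.calls  # MHA reads out_proj.weight, never calls out_proj()

    # Imitate a quantized weight: integer storage instead of floats.
    mha.out_proj.weight = nn.Parameter(
        torch.zeros(embed_dim, embed_dim, dtype=torch.uint8), requires_grad=False
    )
    try:
        mha(x, x, x)  # the raw uint8 weight reaches F.linear
        failed = False
    except RuntimeError:
        failed = True
    return calls, failed
```

demo() returns (0, True). The replacement's forward() is never invoked, and the raw uint8 weight reaches F.linear and raises a dtype-mismatch RuntimeError like the ones in this thread.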