IntelLabs / lvlm-interpret


RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16' in LLAVA-Gemma Model #5

Closed Lancelottery closed 1 month ago

Lancelottery commented 1 month ago

Description

I encountered an issue while running the LLAVA-Gemma model on my local machine. The error message indicates that the "triu_tril_cuda_template" is not implemented for the 'BFloat16' data type.
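For reference, the failure seems reproducible in isolation with a plain torch.triu call on a bfloat16 CUDA tensor (the sketch below assumes a CUDA device and one of the affected PyTorch builds; the tensor is just illustrative):

import torch

# Build a small bfloat16 mask on the GPU, similar to the causal mask built in modeling_gemma.py
mask = torch.full((8, 8), float("-inf"), dtype=torch.bfloat16, device="cuda")

# On affected PyTorch builds this raises:
# RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
causal_mask = torch.triu(mask, diagonal=1)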

Error Message

Here is the complete error trace:

python app.py --model_name_or_path intel-llava-gemma-2b --load_8bit --port 8080

/path/to/env/venv/lib/python3.11/site-packages/transformers/deepspeed.py:24: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
  warnings.warn(
INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` to a higher value to use more memory (at your own risk).
Loading checkpoint shards: 100%|████████████████| 3/3 [00:05<00:00,  1.78s/it]
Running on local URL:  http://0.0.0.0:8080
INFO:httpx:HTTP Request: GET http://localhost:8080/startup-events "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: HEAD http://localhost:8080/ "HTTP/1.1 200 OK"

To create a public link, set `share=True` in `launch()`.
INFO:utils_gradio:Describe this image.
/path/to/env/venv/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
WARNING:utils_model:Attention weights were not returned for the vision model. Relevancy maps will not be calculated for the vision model. To enable, set output_attentions=True in the forward pass of vision_tower. 
Traceback (most recent call last):
  File "/path/to/env/venv/lib/python3.11/site-packages/gradio/queueing.py", line 536, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/gradio/route_utils.py", line 276, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/gradio/blocks.py", line 1897, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/gradio/blocks.py", line 1483, in call_function
    prediction = await anyio.to_thread.run_sync(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 2177, in run_sync_in_worker_thread
    return await future
           ^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 859, in run
    result = context.run(func, *args)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/gradio/utils.py", line 816, in wrapper
    response = f(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/gradio/utils.py", line 136, in lvlm_bot
    outputs = model.generate(
              ^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/transformers/generation/utils.py", line 1914, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/transformers/generation/utils.py", line 2651, in _sample
    outputs = self(
              ^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/transformers/models/llava/modeling_llava.py", line 476, in forward
    outputs = self.language_model(
              ^^^^^^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/transformers/models/gemma/modeling_gemma.py", line 1127, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/transformers/models/gemma/modeling_gemma.py", line 890, in forward
    causal_mask = self._update_causal_mask(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/env/venv/lib/python3.11/site-packages/transformers/models/gemma/modeling_gemma.py", line 1024, in _update_causal_mask
    causal_mask = torch.triu(causal_mask, diagonal=1)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'

Environment

Steps to Reproduce

  1. Install the specified versions of torch, torchvision, transformers, and CUDA.
  2. Download the entire intel/llava-gemma-2b repository from Hugging Face.
  3. Run the command: python app.py --model_name_or_path intel-llava-gemma-2b --load_8bit --port 8080
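
For completeness, the relevant environment versions can be dumped with standard torch/transformers attributes (a small convenience snippet, nothing specific to this repo):

import torch, transformers

print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("CUDA (build):", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    print("bf16 supported:", torch.cuda.is_bf16_supported())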

Additional Information

p.s. Do I need to worry about the "Attention weights were not returned for the vision model" warning? I was using the Intel model, so I assumed everything would already be set to the right defaults.


Thank you so much for the excellent work! I am looking forward to your reply!

Lancelottery commented 1 month ago

I fixed this by manually setting torch_dtype=torch.float16 in utils_model.py. I also fixed the attention-weights warning by wrapping the forward of each vision-tower self-attention layer so that output_attentions=True is always passed.

import torch
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration

def forward_with_attentions(module):
    # Patch the module so every forward call requests attention weights.
    original_forward = module.forward

    def new_forward(*args, **kwargs):
        kwargs['output_attentions'] = True
        return original_forward(*args, **kwargs)

    module.forward = new_forward

def get_processor_model(args):
    processor = AutoProcessor.from_pretrained(args.model_name_or_path)

    if args.load_4bit:
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16
        )
    elif args.load_8bit:
        quant_config = BitsAndBytesConfig(
            load_in_8bit=True
        )
    else:
        quant_config = None

    model = LlavaForConditionalGeneration.from_pretrained(
        args.model_name_or_path,
        torch_dtype=torch.float16,  # fp16 sidesteps the bf16 triu kernel error
        quantization_config=quant_config, low_cpu_mem_usage=True, device_map="auto"
    )

    # Wrap each vision-tower self-attention layer so attention weights are returned
    for layer in model.vision_tower.vision_model.encoder.layers:
        forward_with_attentions(layer.self_attn)

    .....