turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

Input the embedding tensor into LLMs? #399

Open aliencaocao opened 2 months ago

aliencaocao commented 2 months ago

If I want to work with multimodal LLMs that take in a set of embeddings from vision/audio encoders, what is the proper way of inputting them into an LLM running on exllamav2? Can I just add a custom if/else here at https://github.com/turboderp/exllamav2/blob/f6b7faa429080cd7c7e394ec301442fbf137658f/exllamav2/embedding.py#L79 to decide whether the text embedding layer should run? The initial embeddings coming in can be either the multimodal part alone or concatenated with text token embeddings.

turboderp commented 2 months ago

ExLlama defines the forward pass as a list of modules, each of which takes an input tensor and returns an output tensor. At the beginning of the forward pass the shape of this tensor is assumed to be (batch_size, seq_len), but after that point it's really defined by each individual architecture how the shape and datatype of the hidden state changes over the forward pass.

I think I would probably prefer to add a single dictionary argument to ExLlamaV2.forward which can be passed along to each module (layer) to carry extra information like input embeddings, using pseudotokens as an index. ExLlamaV2Embedding.forward would then build the first hidden state either from the model's own embedding table or from the provided input embeddings or both.

The idea would be that, say you're providing 150 input embeddings, these could be represented as token IDs 1000000 to 1000149 (or something, token IDs are 64 bits so range isn't an issue), allowing you to mix text/image/audio data arbitrarily in the context.
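
Roughly, the lookup could work like this (a standalone PyTorch sketch, not the actual module code; the function name and shapes are just for illustration):

import torch

def build_initial_hidden_state(input_ids, embedding_table, indexed_embeddings, offset = 1000000):
    # input_ids:          (batch_size, seq_len), mixing real token IDs and pseudotoken IDs >= offset
    # embedding_table:    (vocab_size, hidden_size), the model's own embeddings
    # indexed_embeddings: (batch_size, num_extra, hidden_size), vision/audio embeddings
    is_pseudo = input_ids >= offset
    # Normal tokens come from the embedding table (clamp pseudotokens to a valid index first)
    safe_ids = input_ids.masked_fill(is_pseudo, 0)
    hidden = embedding_table[safe_ids]
    # Pseudotoken positions are overwritten with the externally provided embeddings
    ext_index = (input_ids - offset).masked_fill(~is_pseudo, 0)
    batch_index = torch.arange(input_ids.shape[0]).unsqueeze(-1).expand_as(input_ids)
    external = indexed_embeddings[batch_index, ext_index].to(hidden.dtype)
    return torch.where(is_pseudo.unsqueeze(-1), external, hidden)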

This would also integrate with the existing generator and cache systems fairly easily, with some small changes to prevent the tokenizer from trying to decode the pseudotokens.

One other small detail is that ExLlama keeps the embedding table in system RAM, where the initial hidden state is also built before being moved to the first CUDA device. Maybe that could be reworked a bit to more efficiently mix in the audio/vision embeddings from another model computing them in VRAM.

I can sketch up a rough version of it and push it to the dev branch in a few hours.

aliencaocao commented 2 months ago

That would be very helpful! For context, I am hoping to use LLaVA with exllamav2 here; the LLM will be Mistral 7B v0.2.

turboderp commented 2 months ago

Looks like it will take a little longer. Probably can't get to it until tomorrow. But in the meantime, would it be necessary to mask out the extra embeddings to avoid applying position embeddings to them?

aliencaocao commented 2 months ago

Yes. What I am thinking is just to borrow HF's impl for the initial embeddings (image + prompt token embeddings from a one-time call to the LM's forward with token IDs), then pass this tensor into EXL2's Embedding module's forward as hidden_states. I think this will be the fastest way. In HF's impl, the RotaryEmbedding is only applied in the forward of the language model, here: https://github.com/huggingface/transformers/blob/240e10626b10574899ecd9a3ddcc47788f289732/src/transformers/models/mistral/modeling_mistral.py#L275 This leads me to think that it only applies to the prompt text tokens but not the image embeddings, so I do think it is necessary to mask them out.

One thing I am not sure about is whether EXL2 will be able to get the text embeddings for the input prompt tokens alone and then concat them with the image embeddings passed in at the same forward call. This is currently done in HF in https://github.com/huggingface/transformers/blob/240e10626b10574899ecd9a3ddcc47788f289732/src/transformers/models/llava_next/modeling_llava_next.py#L356 That needs to be done so that we don't have to call any forward of the LM using HF, and can instead supply the token IDs directly to EXL2 for the embeddings and then concat them with the image embeddings.

aliencaocao commented 2 months ago

I think a cleaner way would be to provide an additional arg that is exclusively for passing in non-text embeddings like image/audio, have them replace a special token in the text input IDs (LLaVA uses <image>), and then mask the indices of these embeddings from the position embedding, instead of passing in a mixture of multimodal embeddings and text token IDs.
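
Something like this is what I have in mind (a standalone sketch; the helper name is made up, and 32000 is LLaVA's <image> token ID):

import torch

def splice_image_embeddings(input_ids, text_embeds, image_embeds, image_token_id = 32000):
    # input_ids:    (1, seq_len) token IDs containing exactly one <image> placeholder
    # text_embeds:  (1, seq_len, hidden_size) embeddings of the text tokens
    # image_embeds: (1, n_img, hidden_size) output of the vision encoder + projector
    pos = (input_ids[0] == image_token_id).nonzero(as_tuple = True)[0].item()
    # Replace the single placeholder with the full run of image embeddings
    merged = torch.cat((text_embeds[:, :pos], image_embeds, text_embeds[:, pos + 1:]), dim = 1)
    # True wherever the position came from the image, i.e. where RoPE should be skipped
    rope_mask = torch.zeros(merged.shape[1], dtype = torch.bool)
    rope_mask[pos : pos + image_embeds.shape[1]] = True
    return merged, rope_mask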

turboderp commented 2 months ago

I'll see if I can work out some elegant way of not adding position embeddings to the image data. And yes, what I'm imagining is pretty much that. But I think it would have to be something like:

offset = 1000000

image_embeddings = clip_model.whatever(...)  # shape (batch_size, n_img_tokens, hidden_size)
text_ids = tokenizer.encode("<- what is this?")  # shape (batch_size, n_text_tokens)

n_img_tokens = image_embeddings.shape[1]
image_ids = torch.arange(offset, offset + n_img_tokens).unsqueeze(0)  # shape (1, n_img_tokens)
input_ids = torch.cat((image_ids, text_ids), dim = -1)

logits = model.forward(input_ids, extra_params = { "indexed_embeddings": image_embeddings })

The reason is that then most of the assumptions made elsewhere in the framework will still hold. I.e. the cache space occupied is reflected in the length of the input sequence, the generator's sequence IDs will align with the cache entries, cache reuse will still work, etc. And it will be possible to embed multiple images or other types of embeddings if necessary, interleaved with text in an arbitrary fashion.

And of course with a little tweak to the tokenizer to prevent any attempts to decode those index tokens, it would also work seamlessly with the generators, as in

...
generator.begin_stream_ex(input_ids)
while True:
    stream_result = generator.stream()
    print (stream_result["chunk"], end = "")
    if stream_result["eos"]: break
aliencaocao commented 2 months ago

Do you have a more specific idea of how this should be implemented? I have gotten the input embeddings into model.forward, but it seems to output logits with the same sequence dimension as my input. The input embeddings are of shape (1, 2629, 4096) and the logits are the same. I am confused about this because I thought it should only give 1 token at a time. However, if I set config.max_output_len=1 then it complains that the input exceeds the max length. Decoding the logits gives close to nonsense, although some tokens look correct; it's just filled with gibberish in between them. Do you have any idea where I went wrong?

turboderp commented 2 months ago

I have it more or less implemented. But then I got a little sidetracked by cmdr+. I'll get back to it very soon. Masking out the position embeddings requires rewriting the RoPE kernel, but other than that it's pretty much done.

aliencaocao commented 2 months ago

Thanks. Do you mean that with the existing Python code I cannot prevent RoPE from being applied to the input embeddings?

turboderp commented 2 months ago

No, the CUDA kernels can support an offset but not a mask.

I've just pushed the cmdr+ changes to the dev branch, and there's also the unfinished and untested indexed embeddings stuff in there. Here: https://github.com/turboderp/exllamav2/commit/0375443142212f808aa14f1600aebbc8f4b45ffc

Next I'll add position IDs, I think, since that'll be the most flexible way to control the embeddings. I'll make it use any negative position as a mask to indicate tokens that shouldn't have RoPE applied.
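
For illustration, that convention could look roughly like this at the Python level (just a sketch; the real thing lives in the CUDA RoPE kernel):

import torch

def rope_cos_sin(position_ids, head_dim, theta = 1000000.0):
    # position_ids: (seq_len,) with -1 marking tokens that should not be rotated.
    # Clamping negatives to zero yields angle 0 for those rows, i.e. the identity
    # rotation (cos = 1, sin = 0), which is the same as not applying RoPE to them.
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    pos = position_ids.clamp(min = 0).float()
    angles = torch.outer(pos, inv_freq)    # (seq_len, head_dim / 2)
    return angles.cos(), angles.sin()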

Although, I'm wondering if the image tokens really shouldn't have any position embeddings applied, or if there is something else involved to, say, allow multiple images in a context?

aliencaocao commented 2 months ago

The CLIP encoder already applies positional embeddings to the image tokens, so I think applying them again would duplicate it. https://github.com/huggingface/transformers/blob/76fa17c1663a0efeca7208c20579833365584889/src/transformers/models/clip/modeling_clip.py#L193

In HF's code, however, https://github.com/huggingface/transformers/blob/76fa17c1663a0efeca7208c20579833365584889/src/transformers/models/llava_next/modeling_llava_next.py#L596 they do pass the entire input embedding into the apply-rotary-embedding func, but they have position_ids to track it. I don't know if EXL2 has something similar. Their position_ids come from https://github.com/huggingface/transformers/blob/76fa17c1663a0efeca7208c20579833365584889/src/transformers/models/llava_next/modeling_llava_next.py#L418

I do have a half-working patch that just passes in the input embeddings, and it does seem to work, but extremely unreliably. At temp=0 it gives basically nonsense, and at temp=1 it gives something sensible (as good as the raw HF impl) about 10% of the time. I don't know if this randomness is to be expected, tbh, because if it were due to the double application of positional embeddings then it should not be random.

turboderp commented 2 months ago

If it fails consistently at temp=0 but works 10% of the time with sampling, that does sound like double embeddings could be the issue. But it's also still untested. It took a lot longer than I wanted to get cmdr+ to convert, but I'm back on it now, and I'll try to actually test it to see if I maybe screwed something up. And of course get the position IDs working.

aliencaocao commented 2 months ago

Oh, I am not using your latest commits; I'm just modifying 0.0.17 to add input_embedding to the relevant functions and passing it around. On a closer look, I think the masking may not be needed. HF's position_ids that get passed into the model are simply 1 to 2000, increasing linearly, so there is no "masking" here like setting the image tokens' positions to the same value. These same position IDs get processed by HF's language model code and positional embeddings are applied again.

I think the problem might be that I am passing the text tokens through a forward of the embedding layer directly, and it may have applied the (wrong) positional embeddings to my text tokens. Does your embedding layer already apply positional embeddings?

This is how I am forwarding the text tokens now:

prompt = 'describe this image'
input_ids = tokenizer.encode(prompt, add_bos=True, encode_special_tokens=True).to('cuda')
embedding_layer = next(m for m in model.modules if isinstance(m, ExLlamaV2Embedding))
embedding_layer.embedding.to('cuda')
inputs_embeds = embedding_layer.forward(hidden_states=input_ids)

inputs_embeds is then concatenated with the image embeddings and passed into the LLM's forward again for the first token prediction.

turboderp commented 2 months ago

The embedding layer doesn't apply any position embeddings on its own, it only looks up the relevant embedding vectors and concatenates them to produce the initial hidden state. The subsequent attention layers then do the RoPE starting from position zero, on whatever embeddings come out of the first layer.

How are you sending input_embeds to the forward pass, though? You'll have to skip the first module in model.modules or it will just run the embedding layer again. Also, the forward pass is chunked, so if there are too many image tokens it could lead to some funny business.

aliencaocao commented 2 months ago

Here are all my patches so far: https://github.com/aliencaocao/exllamav2/commit/e393554ce4254eb4a6dd2653170130abf65e2377

I skipped the embedding module here https://github.com/aliencaocao/exllamav2/commit/e393554ce4254eb4a6dd2653170130abf65e2377#diff-782674f111f12011bffae70945860194a3dd1241273bd285c728587edad4fe3bR757

The image tokens + prompt are 2652 tokens (the input embedding shape is (1, 2652, 4096); 4096 is the embedding dim).

One notable change I made was to not call self._gen_begin_base on the input embeddings, as it fills the cache prematurely, causing it to complain afterwards when I call forward again. Judging from the code, I don't think this should have affected anything.

Configs:

config.max_input_len = 4096  # if I don't set this it complains about exceeding the max token count
config.max_output_len = 4096
settings.temperature = 0.9  # also tried 0.0 but it constantly fails; lower values didn't make it more stable
settings.top_p = 0.9
# settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])   # according to the lmformatenforcer example this is needed, but it just goes on with nonsense
# settings.filters = [ExLlamaV2TokenEnforcerFilter(regex_formatter, tokenizer)]  # didn't really work as it keeps saying no valid tokens left, an effect of the LM outputting gibberish

How might the chunking affect this?

turboderp commented 2 months ago

If you don't also set config.max_attention_size = 4096**2 it will still chunk to the default chunk size of 2048. You'll then have to also slice the input_embeds tensor as it's passed from model.forward to model._forward.

... which I see you're doing. Still a little hard to tell if everything is slicing correctly though. Looking over the code I can't really tell if the embeddings are being sent to the forward pass at all at any point?
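
For reference, something like this should keep a ~2700-token prefill in a single chunk (values picked for this example):

config.max_input_len = 4096            # allow the whole image+text prompt in one pass
config.max_attention_size = 4096 ** 2  # otherwise the forward pass still chunks at 2048 tokens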

aliencaocao commented 2 months ago

It still seems to me to be something related to positional embeddings. temp=0 gives basically a repetition of my prompt (I gave 2 examples and it just repeated 1). The last time I had this issue on the HF impl, it was due to a wrongly set rope theta. The model I use has 1,000,000 (Mistral 7B v0.2). When I converted it to EXL2 8-bit, I also didn't pass in any special args regarding RoPE.

python exllamav2/convert.py \
    -i models/llava-v1.6-mistral-7b-hf-llm/ \
    -o models/exl2_tmp/ \
    -cf models/llava-v1.6-mistral-7b-hf-llm-exl2-8bit/ \
    -l 4096 \
    -b 8.0 \
    -hb 8

Could it be that the RoPE code isn't reading the config properly, or that I have converted it wrongly?

Looking over the code I can't really tell if the embeddings are being sent to the forward pass at all at any point

I think they are being sent, because the generated response does clearly relate to the image; it just tends to mention some keywords from the image and then the rest becomes gibberish or repetition.

turboderp commented 2 months ago

The position embeddings are based on the length of the cache, and a number of other things have to be kept aligned, which is why it's a bit tricky to sidestep everything like that.

I did finish the indexed embedding stuff though, and I added support in the base generator. Haven't tested it with actual image tokens, but it seems like it's inserting them correctly now. And because they're essentially treated like a second, temporary vocabulary, the generator and chunking logic should be indifferent to them, as should the RoPE stuff.

With the latest commit you should be able to do:

from exllamav2.embedding import EMBEDDING_INDEX

image_emb = torch.randn((1, 420, 4096), dtype = torch.half, device = "cuda:0")
image_ids = torch.arange(EMBEDDING_INDEX, EMBEDDING_INDEX + image_emb.shape[1],
                         dtype = torch.long).unsqueeze(0)

text_ids = tokenizer.encode("Describe the image")
input_ids = torch.cat((image_ids, text_ids), dim = -1)

logits = model.forward(input_ids, indexed_embeddings = image_emb)

Or, using the generator:

image_emb = torch.randn((1, 420, 4096), dtype = torch.half, device = "cuda:0")
prompt = "Describe the image:"
output = generator.generate_simple(prompt, settings, max_out_tokens, input_embeddings = image_emb)
print (output)
aliencaocao commented 2 months ago

My bad if you received notifications for deleted comments - I realised I forgot to recompile the C extension after switching branches.

I just tested it out and it is about as random as my original patch.

If I shorten the prompt, it just gives this: [image]

aliencaocao commented 2 months ago

Actually it is still slightly different. I need the prompt to be in this [INST] <image>\nDescribe this image [/INST] format, which means I have to insert the image tensors in between the text embeddings (and as a special token). HF does it in this function: https://github.com/huggingface/transformers/blob/76fa17c1663a0efeca7208c20579833365584889/src/transformers/models/llava_next/modeling_llava_next.py#L356

But prior to your patch I did use this function to merge them and it still didn't work properly, so I guess it would be a nice-to-have but not the primary cause here.

turboderp commented 2 months ago

Well, I changed the implementation in generate_simple to take a "{{EMBED_HERE}}" marker in the prompt to control where the embeddings are inserted. Try if that helps?

Do you have a complete example for how you're generating the embeddings and a link to the relevant models? It would probably help if I could debug it side by side with the transformers reference implementation instead of trying to parse it all in my smooth brain.

aliencaocao commented 2 months ago
{
  "architectures": [
    "LlavaNextForConditionalGeneration"
  ],
  "ignore_index": -100,
  "image_grid_pinpoints": [
    [
      336,
      672
    ],
    [
      672,
      336
    ],
    [
      672,
      672
    ],
    [
      1008,
      336
    ],
    [
      336,
      1008
    ]
  ],
  "image_token_index": 32000,
  "model_type": "llava_next",
  "projector_hidden_act": "gelu",
  "text_config": {
    "_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2",
    "architectures": [
      "MistralForCausalLM"
    ],
    "intermediate_size": 14336,
    "max_position_embeddings": 32768,
    "model_type": "mistral",
    "num_key_value_heads": 8,
    "rms_norm_eps": 1e-05,
    "rope_theta": 1000000.0,
    "sliding_window": null,
    "torch_dtype": "bfloat16",
    "vocab_size": 32064
  },
  "torch_dtype": "float16",
  "transformers_version": "4.39.0.dev0",
  "use_image_newline_parameter": true,
  "vision_config": {
    "hidden_size": 1024,
    "image_size": 336,
    "intermediate_size": 4096,
    "model_type": "clip_vision_model",
    "num_attention_heads": 16,
    "num_hidden_layers": 24,
    "patch_size": 14,
    "projection_dim": 768,
    "vocab_size": 32000
  },
  "vision_feature_layer": -2,
  "vision_feature_select_strategy": "default",
  "vocab_size": 32064
}

Copy the config above into models/llava-v1.6-mistral-7b-hf-config/config.json.

import sys
import requests
import traceback
from base64 import b64decode
from contextlib import contextmanager
from functools import partial
from io import BytesIO
from typing import Union, NamedTuple

import fastapi
import safetensors
import torch
from PIL import Image
from lmformatenforcer import RegexParser
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
from pydantic import BaseModel
from torch import nn

import transformers
from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Tokenizer
from exllamav2.embedding import ExLlamaV2Embedding
from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler
from transformers import CLIPVisionModel, LlavaNextConfig, LlavaNextProcessor
from transformers.models.llava_next.modeling_llava_next import LlavaNextMultiModalProjector, get_anyres_image_grid_shape, unpad_image

assert torch.cuda.is_available()
device = "cuda"

def merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids, attention_mask, image_token_index=32000, pad_token_id=0):
    """Copied and simplified from HF"""
    num_images, num_image_patches, embed_dim = image_features.shape
    batch_size, sequence_length = input_ids.shape
    left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(pad_token_id))
    # 1. Create a mask to know where special image tokens are
    special_image_token_mask = input_ids == image_token_index
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
    # Compute the maximum embed dimension
    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
    batch_indices, non_image_indices = torch.where(input_ids != image_token_index)

    # 2. Compute the positions where text should be written
    # Calculate new positions for text tokens in merged image-text sequence.
    # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
    # `torch.cumsum` computes how each image token shifts subsequent text token positions.
    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    if left_padding:
        new_token_positions += nb_image_pad[:, None]  # offset for left padding
    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

    # 3. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
    )
    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
    # set the corresponding tensors into their correct target device.
    target_device = inputs_embeds.device
    batch_indices, non_image_indices, text_to_overwrite = (
        batch_indices.to(target_device),
        non_image_indices.to(target_device),
        text_to_overwrite.to(target_device),
    )
    attention_mask = attention_mask.to(target_device)

    # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
    # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
    # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
    image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

    if image_to_overwrite.sum() != image_features.shape[:-1].numel():
        raise ValueError(
            f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
            f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
        )

    final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
    final_attention_mask |= image_to_overwrite
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
    batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
    indices_to_mask = new_token_positions[batch_indices, pad_indices]

    final_embedding[batch_indices, indices_to_mask] = 0

    return final_embedding, final_attention_mask, position_ids

batch_size = 1  # have to set here instead of in main.py as vram allocation is done here
model_path = "models/llava-v1.6-mistral-7b-hf-llm-exl2-8bit/"  # Download separately here: https://huggingface.co/panoyo9829/llava-v1.6-mistral-7b-hf-llm-exl2-8bit
prompt_prefix = "[INST] <image>\n"
prompt_postfix = " [/INST]"

config = ExLlamaV2Config(model_path)
top_p = 0.5
temperature = 0.0
max_new_tokens = 512
config.max_input_len = 4096
config.max_output_len = 4096
config.max_batch_size = batch_size
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy=True, batch_size=batch_size)  # Cache needs to accommodate the batch size
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
settings = ExLlamaV2Sampler.Settings()
settings.temperature = temperature
settings.top_p = top_p
settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])
settings.filters = [ExLlamaV2TokenEnforcerFilter(regex_formatter, tokenizer)]

image_processor = LlavaNextProcessor.from_pretrained(model_path, local_files_only=True).image_processor
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"

input_ids = tokenizer.encode(prompt_prefix + prompt + prompt_postfix, add_bos=True, encode_special_tokens=True).to(device)  # need encode_special_tokens to encode <image> into 32000
attention_mask = torch.ones_like(input_ids)
image_input_dict = image_processor(image, return_tensors="pt")

# UNCOMMENT 3 lines below if using my own fork. Comment them if using your dev branch
# embedding_layer: ExLlamaV2Embedding = next(m for m in model.modules if isinstance(m, ExLlamaV2Embedding))
# embedding_layer.embedding.to(device)
# inputs_embeds = embedding_layer.forward(hidden_states=input_ids)

pixel_values, image_sizes = image_input_dict["pixel_values"], image_input_dict["image_sizes"]
batch_size, num_patches, num_channels, height, width = pixel_values.shape
reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)  # (NUM_PATCHES, 3, 336, 336)

HF_config = LlavaNextConfig.from_pretrained("models/llava-v1.6-mistral-7b-hf-config")  # create this dir and copy the json config file above. Only config.json is needed in this folder.
merge_input_ids_with_image_features = partial(merge_input_ids_with_image_features, image_token_index=HF_config.image_token_index, pad_token_id=0)
vision_tower = CLIPVisionModel.from_pretrained("models/llava-v1.6-mistral-7b-CLIP", torch_dtype=torch.float16).to(device)  # download here: https://huggingface.co/panoyo9829/llava-v1.6-mistral-7b-CLIP
multi_modal_projector = LlavaNextMultiModalProjector(HF_config)
safetensors.torch.load_model(multi_modal_projector, "models/llava-v1.6-mistral-7b-CLIP/llava-v1.6-mistral-7b-PROJ.safetensors")
multi_modal_projector.half().to(device)
image_newline = nn.Parameter(torch.empty(HF_config.text_config.hidden_size, dtype=torch.float16)).to(device)

# below are simplified from HF impl of llava-next forward
image_features = vision_tower(reshaped_pixel_values.to(device), output_hidden_states=True)
image_features = image_features.hidden_states[HF_config.vision_feature_layer][:, 1:]
image_features = multi_modal_projector(image_features)

# split up image_features for each of the individual images
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
# if we assume each image has 5 image features (base image + 4 patches)
split_sizes = [image.shape[0] for image in pixel_values]
image_features = torch.split(image_features, split_sizes, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = HF_config.vision_config.image_size // HF_config.vision_config.patch_size
new_image_features = []

for image_idx, image_feature in enumerate(image_features):
    if image_feature.shape[0] > 1:
        base_image_feature = image_feature[0]
        image_feature = image_feature[1:]

        if height * width != base_image_feature.shape[0]:
            raise ValueError("The number of patches is not consistent with the image size.")
        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
            image_sizes[image_idx],
            HF_config.image_grid_pinpoints,
            HF_config.vision_config.image_size,
        )
        image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
        image_feature = unpad_image(image_feature, image_sizes[image_idx])
        image_feature = torch.cat(
            (
                image_feature,
                image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
            ),
            dim=-1,
        )
        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
    else:
        image_feature = image_feature[0]
        image_feature = torch.cat((image_feature, image_newline[None]), dim=0)
    new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
# UNCOMMENT below if using my fork,  comment if using your dev branch
# inputs_embeds, attention_mask, position_ids = merge_input_ids_with_image_features(
#     image_features, inputs_embeds, input_ids, attention_mask
# )
output = generator.generate_simple(prompt, settings, max_new_tokens, input_embeddings=image_features, add_bos=True, seed=1234)
# below for my own fork. Above for your dev branch
# outputs, prob_list = generator.generate_simple(prompt=prompt, input_embedding=inputs_embeds, gen_settings=settings, num_tokens=max_new_tokens, seed=1234)

I have commented where model downloads are required, and also where to uncomment if using my own fork.

Have been testing on torch 2.2.2+cu121.

Thanks

aliencaocao commented 2 months ago

Well, I changed the implementation in generate_simple to take a "{{EMBED_HERE}}" marker in the prompt to control where the embeddings are inserted. Try if that helps?

It doesn't seem to make any difference for me. The output does change a bit, but it's generally still gibberish.

turboderp commented 2 months ago

Okay, so I did a little experimenting, and I've got something that at least works now.

from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor
from transformers.models.llava_next.modeling_llava_next import get_anyres_image_grid_shape, unpad_image
import torch
from PIL import Image
import requests

# Get input image

url = "https://media.istockphoto.com/id/1361394182/photo/funny-british-shorthair-cat-portrait-looking-shocked-or-surprised.jpg?s=612x612&w=0&k=20&c=6yvVxdufrNvkmc50nCLCd8OFGhoJd6vPTNotl90L-vo="
image = Image.open(requests.get(url, stream = True).raw)

# Preprocess image, we only need the patches

preprocessor = LlavaNextProcessor.from_pretrained("/mnt/str/models/llava-v1.6-mistral-7b/")
inputs = preprocessor("", image, return_tensors = "pt")

pixel_values = inputs["pixel_values"].to("cuda:0")
image_sizes = inputs["image_sizes"]

# Load Llava model, we'll only use the vision tower

llavamodel = LlavaNextForConditionalGeneration.from_pretrained(
    "/mnt/str/models/llava-v1.6-mistral-7b/",
    torch_dtype = torch.float16,
    low_cpu_mem_usage = True
).to("cuda:0")

# Get image features

batch_size, num_patches, num_channels, height, width = pixel_values.shape
assert batch_size == 1  # Just testing
reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)

image_features = llavamodel.vision_tower(reshaped_pixel_values, output_hidden_states = True)

selected_image_feature = image_features.hidden_states[llavamodel.config.vision_feature_layer]
if llavamodel.config.vision_feature_select_strategy == "default":
    selected_image_feature = selected_image_feature[:, 1:]

image_features = llavamodel.multi_modal_projector(selected_image_feature)

# Split up image features

split_sizes = [image.shape[0] for image in pixel_values]
image_features = torch.split(image_features, split_sizes, dim = 0)
assert len(image_features) == 1  # Just testing
height = width = llavamodel.config.vision_config.image_size // llavamodel.config.vision_config.patch_size

new_image_features = []
for image_idx, image_feature in enumerate(image_features):
    if image_feature.shape[0] > 1:
        base_image_feature = image_feature[0]
        image_feature = image_feature[1:]

        if height * width != base_image_feature.shape[0]:
            raise ValueError("The number of patches is not consistent with the image size.")
        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
            image_sizes[image_idx],
            llavamodel.config.image_grid_pinpoints,
            llavamodel.config.vision_config.image_size,
        )
        image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
        image_feature = unpad_image(image_feature, image_sizes[image_idx])
        image_feature = torch.cat(
            (
                image_feature,
                llavamodel.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
            ),
            dim = -1,
        )
        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
        image_feature = torch.cat((base_image_feature, image_feature), dim = 0)
    else:
        image_feature = image_feature[0]
        image_feature = torch.cat((image_feature, llavamodel.image_newline[None]), dim = 0)
    new_image_features.append(image_feature)

image_features = torch.stack(new_image_features, dim = 0)

# EXL2 model

from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler

model_directory = "/mnt/str/models/mistral-7b-instruct-exl2/5.0bpw"
config = ExLlamaV2Config(model_directory)
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True)
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.top_k = 1

# Generate

prompt = "[INST] {{EMBED_HERE}}\nWhat is shown in this image? [/INST]"

output = generator.generate_simple(
    prompt,
    gen_settings,
    100,
    input_embeddings = image_features,
    add_bos = True
)

print (output)

The code loads llava-v1.6-mistral-7b-hf as a HF model and uses just its vision tower and multimodal projector to get the image features.

ExLlama then does the inference, using a quantized version of Mistral-instruct-v0.1, which seems to work alright even though the Llava model is made from Mistral-instruct-v0.2.

Next step I guess would be to extract the relevant parts from the Llava model so it doesn't have to load a whole redundant copy of the LM.

[test image 1]

[INST] What is shown in this image? [/INST] The image shows an orange cat with its mouth open, displaying its teeth. The cat's fur is gray and it appears to be looking directly at the camera.

[test image 2]

[INST] What is shown in this image? [/INST] This image shows a yellow and black cartoon character with a wide smile, wearing blue jeans and a red shirt. The character has two large black eyes and a black nose, and its mouth is open wide. The character's hair is black and styled in a playful way, with a few strands sticking out from the top of its head. The background is white, and there are no other objects or people visible in the image.

aliencaocao commented 2 months ago

Next step I guess would be to extract the relevant parts from the Llava model

My code actually already does this, and it only loads the vision tower and projector. What changes did you make / what made you load the full model?

turboderp commented 2 months ago

I'm not sure exactly where it deviates from your example. I had some trouble getting it to run, and I wasn't quite sure I was using the right models, so instead I just started from the example code snippet in the llava-v1.6-mistral readme, mostly to familiarize myself with how it's supposed to work in HF.

turboderp commented 2 months ago

I did a little more experimenting, and I can get your version to run (unreliably) if I use a prompt of [INST] {{EMBED_HERE}}\nWhat is shown in this image? [/INST] and a different model than the one you provided. Possibly there's something off with it, and I do notice that according to the config it's actually Llama-7B, not Mistral?

Anyway, the reason it's unreliable seems to be due to the input embeddings. Here's my version for the image you linked to:

tensor([[[ 3.2410e-02, -3.6694e-01,  2.2559e-01,  ..., -4.4250e-02,
           1.3599e-01, -9.8694e-02],
         [ 1.6113e-01, -4.9683e-01,  2.9346e-01,  ..., -1.6235e-01,
          -6.9122e-03,  1.5002e-01],
         [ 1.1218e-01, -1.7859e-01,  3.0054e-01,  ..., -1.1554e-01,
          -1.2415e-01, -4.7882e-02],
         ...,
         [ 1.6736e-01, -5.2002e-01,  2.5586e-01,  ..., -9.2224e-02,
           2.9016e-04,  2.3267e-01],
         [-1.3657e-02, -5.7471e-01,  2.3413e-01,  ...,  4.4708e-02,
           1.1053e-01,  1.9638e-02],
         [-1.6602e-02, -6.8359e-03,  1.2283e-03,  ...,  5.9509e-03,
           1.8433e-02,  7.5531e-04]]], device='cuda:0', dtype=torch.float16,
       grad_fn=<StackBackward0>)

[INST] 
What is shown in this image? [/INST] The image shows a diagram of the performance and efficiency of various machine learning models, represented as a red and orange triangle (MM-Bench), a green hexagon (VizWiz), and a blue pentagon (TextVa-OwC-N). The numbers inside each shape represent the accuracy and time taken for that specific model to process data. The arrows show the direction of improvement from one model to another. The goal is to find the optimal trade-off between accuracy and time taken for different models. In this case, the best model according to the given information is TextVa-OwC-N as it achieves both high accuracy and short processing time.

Here is the reference from the HF version:

tensor([[[ 3.2410e-02, -3.6694e-01,  2.2559e-01,  ..., -4.4250e-02,
           1.3599e-01, -9.8694e-02],
         [ 1.6113e-01, -4.9683e-01,  2.9346e-01,  ..., -1.6235e-01,
          -6.9122e-03,  1.5002e-01],
         [ 1.1218e-01, -1.7859e-01,  3.0054e-01,  ..., -1.1554e-01,
          -1.2415e-01, -4.7882e-02],
         ...,
         [ 1.6736e-01, -5.2002e-01,  2.5586e-01,  ..., -9.2224e-02,
           2.9016e-04,  2.3267e-01],
         [-1.3657e-02, -5.7471e-01,  2.3413e-01,  ...,  4.4708e-02,
           1.1053e-01,  1.9638e-02],
         [-1.6602e-02, -6.8359e-03,  1.2283e-03,  ...,  5.9509e-03,
           1.8433e-02,  7.5531e-04]]], device='cuda:0', dtype=torch.float16)

[INST]  
What is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point. This particular radar chart is showing the performance of different models or systems across various metrics.

The axes represent different metrics or benchmarks, such as MM-Vet, MM-Vet, MM-Vet, MM-Vet, MM-Vet, MM-Vet, MM-Vet, MM-Vet, MM-Vet, MM-Vet, MM-Vet, MM-Vet, MM-Vet, MM-Vet

Here's yours

tensor([[[ 3.2410e-02, -3.6694e-01,  2.2559e-01,  ..., -4.4250e-02,
           1.3599e-01, -9.8694e-02],
         [ 1.6113e-01, -4.9683e-01,  2.9346e-01,  ..., -1.6235e-01,
          -6.9122e-03,  1.5002e-01],
         [ 1.1218e-01, -1.7859e-01,  3.0054e-01,  ..., -1.1554e-01,
          -1.2415e-01, -4.7882e-02],
         ...,
         [ 1.6736e-01, -5.2002e-01,  2.5586e-01,  ..., -9.2224e-02,
           2.9016e-04,  2.3267e-01],
         [-1.3657e-02, -5.7471e-01,  2.3413e-01,  ...,  4.4708e-02,
           1.1053e-01,  1.9638e-02],
         [ 1.6602e-02, -1.2494e-01,         nan,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]], device='cuda:0', dtype=torch.float16,
       grad_fn=<StackBackward0>)

[INST] 
What is shown in this image? [/INST]
m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m/m

The last row of the image_features tensor is different, possibly uninitialized due to an off-by-one error. It could be where you call LlavaNextProcessor.image_processor with just the image data, while I'm calling LlavaNextProcessor with the image data and an empty prompt.

aliencaocao commented 2 months ago

according to the config it's actually Llama-7B, not Mistral

This is fine because they are basically identical for v0.2. The implementation differed for v0.1 due to sliding-window attention, but v0.2 removed it, so it's now exactly the same as Llama except for the longer context of 32768 and a rope theta of 1000000, both of which are in the config and not relevant to the Python code.

Could be where you call LlavaNextProcessor.image_processor with just the image data, and I'm calling LlavaNextProcessor with the image data and an empty prompt.

That's not the reason here, as even if I call the full preprocessor it still gives gibberish and the last line still differs.

The last line of the image tensor being different should be because image_newline is initialized as image_newline = nn.Parameter(torch.empty(HF_config.text_config.hidden_size, dtype=torch.float16)).to(device) in my code, whereas in yours it comes from the loaded (full) model.

I also noticed this because when I use your code + my own models with image_newline initialized my way, your provided examples fail very hard too.

What I am confused about is that this is literally how HF initializes it: https://github.com/huggingface/transformers/blob/76fa17c1663a0efeca7208c20579833365584889/src/transformers/models/llava_next/modeling_llava_next.py#L310

Unless I have some misunderstanding about how torch.empty works, it is supposed to be random in the first place, because it just takes a chunk of memory and uses whatever is inside it as the value. Turns out, if you load the full model in HF, it is not random. I don't get this part yet. My PC OOMs if I try to load both the fp16 model and the exllama one, so I can't really test...

turboderp commented 2 months ago

You can stick this before loading the EXL2 model in my version:

# Unload HF Llava model

import gc
del llavamodel
del preprocessor
del base_image_feature
del image
del image_feature
del new_image_features
del pixel_values
del reshaped_pixel_values
del selected_image_feature
del inputs

torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()

And maybe add:

config.max_seq_len = 4096

As for torch.empty, it does give you unallocated/undefined memory, but any nn.Parameter defined in a module becomes part of that module's state dictionary. So it's trainable, and there's a corresponding tensor in the model weights that should be loaded.
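
Quick way to see it (a toy module, nothing to do with Llava specifically):

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # torch.empty leaves the values undefined here...
        self.image_newline = nn.Parameter(torch.empty(4096, dtype = torch.float16))

m = Toy()
print("image_newline" in m.state_dict())  # True: it's a registered, loadable weight
# ...but from_pretrained / load_state_dict replaces it with the trained values,
# so only a module constructed without loading weights keeps the garbage.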

aliencaocao commented 2 months ago

Thanks for the help. I figured I just had to get the original image_newline embedding from the full model's weights, save it separately, then load it:

import safetensors
from safetensors.torch import save_file
import torch
from torch import nn
from transformers import LlavaNextForConditionalGeneration

model = LlavaNextForConditionalGeneration.from_pretrained('models/llava-v1.6-mistral-7b-hf', low_cpu_mem_usage=True, torch_dtype=torch.float16)

newline = model.image_newline.data
save_file({'image_newline': newline}, MODEL_NAME + '-CLIP/newline.safetensors')

# in exllama test script:
newline = safetensors.safe_open('models/llava-v1.6-mistral-7b-hf-CLIP/newline.safetensors', framework='pt').get_slice('image_newline')[:]
image_newline = nn.Parameter(newline, requires_grad=False).to(device)

With a simple fix like this, my original example works like a charm:


[INST] 
What is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional graph that displays data on two or more quantitative variables represented on axes starting from the same point. This particular radar chart is showing performance metrics for different models or systems, likely related to machine learning or artificial intelligence.

The axes represent various evaluation metrics such as MM-Vet, MM-Vet, GQA, VizWiz, SQuAD, and others. Each model or system is represented by a line that intersects at different points along these axes, indicating its performance across each metric.

The colors and labels on the right side of the image suggest that there are different versions or configurations of the models being compared, such as BLIP-2, InstructionBLIP, and others. The numbers along the axes represent specific scores or values for each model on each metric.

Without additional context, it's difficult to provide a detailed interpretation of the data, but generally, this kind of chart is used to compare the performance of different models or systems across multiple tasks or datasets.
aliencaocao commented 2 months ago

An unrelated question: is there a convenient switch to not return the prompt in output tokens and logits? I am totally fine with removing it myself but just asking in case I'm doing extra work.

EDIT: i opened https://github.com/turboderp/exllamav2/pull/402 for this

turboderp commented 2 months ago

Oh, I just added a completion_only option. I'll look at the PR in a bit, but I have to bump to 0.0.18 and release today.

aliencaocao commented 2 months ago

Do you think there's a need to run the calibration for quantizing with an image-text dataset for the LLM part of Llava models, or will the default wikitext work? From my own tries, I get a significant performance loss at 6 bpw on a finetuned Llava 7B.

turboderp commented 2 months ago

The default isn't wikitext, it just contains wikitext. There's also a lot of multilingual data, scrambled text and random tokens. The quantization is definitely optimized for text, and it's unclear how the image tokens end up aligning with that. I would have to imagine the image tokens end up aligning somewhat since they ultimately have to be interpreted by the model without any finetuning, but I really don't have the data to say anything for sure.

Increasing the damping here will lessen the overall effect of calibration, so that's something to try if you're not getting the performance you'd expect. 6 bpw isn't going to be especially affected to begin with, though. Calibration isn't finetuning, after all. It's much more like "smart rounding", and the rounding error is quite small at 6 bpw.

aliencaocao commented 2 months ago

I did my own calibration with an image-text dataset and the result is very good: I am getting a 3% improvement in metrics compared to fp16, whereas using the default calibration data I get a 20% drop. It seems to play a huge role here.

I used 20 rows of 4096 tokens of embeddings.

turboderp commented 2 months ago

I guess it would make sense to compile dedicated Llava models with the vision tower included, along with weights calibrated for the image tokens. Did you do any testing of how the image embeddings generalize? I.e., if you calibrate the quantization using drawings, would it still perform well enough on photos?

aliencaocao commented 2 months ago

In any current quant implementation, including HF's BNB, the vision tower and the projector are kept in fp16 (the weights are in fp32 but loaded as fp16). In my test, after warming up once, it takes less than 0.01 seconds to perform a forward pass on a relatively large image (1000x1000) on a V100. Thus, I don't think quantizing them is necessary, and I think there could be significant accuracy loss if you did, given their smaller param size.

As for calibration, our specific task was image captioning + classification of memes, so I guess it leans towards document understanding + photos. Our calibration set is also meme + text pairs. We did not test how generalizable the calibration is to other images like drawings. I can make a PR to add support for inputting an arbitrary .safetensors file as the calibration embeddings instead of using the built-in data; these embeddings would have to be generated and stored beforehand by an fp16 model like the HF impl (see the sketch below).

For the same project, I also implemented batching for embedding inputs, but the performance does take a (small) hit due to the lack of position offsets for image padding tokens that sit in the middle of the input IDs. The current impl only allows for one position offset, which is counted from the end of the tokenizer output.
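
Something like this is what I have in mind for dumping the calibration embeddings (just a sketch; the shapes, tensor name and file name are placeholders):

import torch
from safetensors.torch import save_file

# e.g. 20 rows of 4096 tokens with hidden size 4096, produced by the fp16 HF model
calibration_rows = torch.randn(20, 4096, 4096, dtype = torch.float16)  # placeholder data
save_file({"input_embeddings": calibration_rows}, "calibration_embeddings.safetensors")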

xonfour commented 1 month ago

Just to be sure: there is currently no way of using this with generator.begin_stream_ex(), correct?

turboderp commented 1 month ago

I added it to the streaming generator and pushed to the dev branch just now. I still want to add a more complete interface for it, with functions for ingesting images and support for multiple sets of embeddings in the same prompt, and stuff.

Currently, the generator removes the image IDs from the token sequence after ingesting the prompt, since otherwise subsequent generations with a new prompt might reuse (some of) the precomputed keys/values from a previous image. This means it will only reuse keys/values up to the first embedding token. I guess the generator will need to keep a hash of the embedding tensor for cache reuse to work properly, and that will tie into the rest of the multimodal features, somehow, in the future.
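
Something like this is what I have in mind for the fingerprint (illustration only, not implemented yet):

import hashlib
import torch

def embeddings_fingerprint(emb: torch.Tensor) -> str:
    # Hash the raw bytes of the embeddings so cached keys/values for the pseudotokens
    # are only reused when the exact same embeddings are passed in again
    return hashlib.sha256(emb.detach().cpu().contiguous().numpy().tobytes()).hexdigest()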

Anyway, you'll need to construct the input IDs with the indexed embeddings, something like this:

# Encode prompt

import sys

from exllamav2.embedding import EMBEDDING_INDEX
from exllamav2.generator import ExLlamaV2StreamingGenerator

prompt = "[INST] {{EMBED_HERE}}\nWhat is shown in this image? [/INST] The image shows"
prompt_split = prompt.split("{{EMBED_HERE}}")

pre_ids = tokenizer.encode(prompt_split[0].rstrip(" \t"), encode_special_tokens = True, add_bos = True)
post_ids = tokenizer.encode(prompt_split[1].lstrip(" \t"), encode_special_tokens = True, add_bos = False)

num_emb_tokens = image_features.shape[1]
image_ids = torch.arange(EMBEDDING_INDEX, EMBEDDING_INDEX + num_emb_tokens, dtype = torch.long).unsqueeze(0)
ids = torch.cat((pre_ids, image_ids, post_ids), dim = -1)

# Generate

generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
generator.set_stop_conditions([tokenizer.eos_token_id])
generator.begin_stream_ex(ids, gen_settings, input_embeddings = image_features)

remaining_tokens = 100
while remaining_tokens:
    res = generator.stream_ex()
    if res["eos"]: break
    print(res["chunk"], end = "")
    sys.stdout.flush()
    remaining_tokens -= 1
xonfour commented 1 month ago

Basically, it works! So far, setting up and testing different vision models to find a good combo remains a rocky road. I like using Nous Capybara and have been looking for compatible models to get image features from, but I haven't found anything good so far (and most probably there is nothing available). I've tried using a LinearAdapter to make things fit whenever necessary, but even then nothing usable came out... ok, no big surprise... ;-)

Anyway, thanks for providing! I will continue experimenting...