gregor-ge / mBLIP

MIT License

How did you change the LLM from BLIP2 with mT0-XL #5

Open joaopedrosdmm opened 1 year ago

joaopedrosdmm commented 1 year ago

Basically, that is my question: how did you re-align/replace Blip2's LLM (Flan-T5-XL, from what I have seen) with mT0-XL?

joaopedrosdmm commented 1 year ago

I saw your training explanation and I will try it, but if you could provide more guidance on how to use new LLMs with BLIP2, that would be amazing, as I'm training exactly that.

For instance, I would like to use blip2-flan-t5-xxl instead of blip2-flan-t5-xl. Which .bin files would I need to download? How did you figure out that the first .bin for blip2-flan-t5-xl contained the ViT and Q-Former?

I'm sure you already took a look at MiniGPT-4, so you know that if you train a projection layer on top of an already pre-trained ViT + Q-Former and add a new LLM, you get pretty good results. I saw the word LoRA here and in your README. Have you tried training a LoRA for the Q-Former? And how does it compare with MiniGPT-4, which trains only a projection layer?

gregor-ge commented 1 year ago

Hi,

happy to hear that you're interested in mBLIP.

How did you re-align/replace Blip2's LLM (flan-xl from what I have seen) with mT0-XL?

Did you already take a look at our paper (https://arxiv.org/abs/2307.06930)? In short, we initialize the ViT & Q-Former from a BLIP2 checkpoint (namely the flan-t5-xl one) and then first briefly train the linear projection and then the entire Q-Former and projection (and LoRA in the LLM) with prefix language modeling using our translated task mixture data.
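The two stages described above boil down to choosing which parameters receive gradients. A minimal sketch, assuming Hugging Face `Blip2ForConditionalGeneration`-style parameter names (`vision_model.`, `query_tokens`, `qformer.`, `language_projection.`, `language_model.`) and PEFT-style LoRA parameter names containing `lora_`; mBLIP's own module names may differ, and counting `query_tokens` as part of the Q-Former is an assumption:

```python
def is_trainable(name: str, stage: int) -> bool:
    """Decide from a parameter's name whether it is trained in a given stage.

    Stage 1: only the linear projection into the LLM.
    Stage 2: Q-Former (incl. query tokens), projection, and LoRA weights in the LLM.
    The ViT and the original LLM weights stay frozen in both stages.
    """
    if stage not in (1, 2):
        raise ValueError("stage must be 1 or 2")
    if name.startswith("language_projection."):
        return True
    if stage == 1:
        return False
    if name == "query_tokens" or name.startswith("qformer."):
        return True
    # Inside the LLM, only the (assumed PEFT-named) LoRA adapter weights are trained
    return name.startswith("language_model.") and "lora_" in name
```

With a PyTorch model this would be applied as `for name, p in model.named_parameters(): p.requires_grad = is_trainable(name, stage)`.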

I would like to use blip2-flanXXL instead of flanXl, which bins would you need to download? How did you figure out that the first bin for blip2-flan-xl was the vit and qformer?

Sharded HuggingFace checkpoints have an index (https://huggingface.co/Salesforce/blip2-flan-t5-xxl/blob/main/pytorch_model.bin.index.json) that maps each weight to the shard file it is stored in. Just take a look and download what's needed (likely the first bin).
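A small sketch of how the index can be used to find the right shard(s): its `weight_map` entry maps each parameter name to a shard file, so collecting the shards of all non-`language_model` weights tells you what to download. The function name and the toy index are illustrative; the real index could be fetched with `hf_hub_download(repo_id=..., filename="pytorch_model.bin.index.json")` and read with `json.load`.

```python
import json


def shards_needed(index: dict, skip_substring: str = "language_model") -> list:
    """Return the shard files holding every weight whose name lacks skip_substring."""
    weight_map = index["weight_map"]  # {parameter name: shard file name}
    return sorted({shard for name, shard in weight_map.items() if skip_substring not in name})


# Toy index with the same structure as pytorch_model.bin.index.json (contents invented)
toy_index = json.loads("""{
  "metadata": {"total_size": 0},
  "weight_map": {
    "vision_model.embeddings.patch_embedding.weight": "pytorch_model-00001-of-00006.bin",
    "qformer.layernorm.weight": "pytorch_model-00001-of-00006.bin",
    "language_projection.weight": "pytorch_model-00001-of-00006.bin",
    "language_model.shared.weight": "pytorch_model-00002-of-00006.bin"
  }
}""")
print(shards_needed(toy_index))  # ['pytorch_model-00001-of-00006.bin']
```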

Have you tried training a Lora for a qformer?

No, we did not try that. The reason for that is basically that full fine-tuning of the Q-Former will likely work best and the Q-Former itself is small enough (~100M parameters) that using LoRA is not necessary. Compared to the compute requirements from the LLM, the Q-Former is really negligible.

And how does it compare with Minigpt which trains only a progression layer?

We find that training the Q-Former along with the projection layer performs significantly better than training just the projection.

If you have any more specific questions, let me know.

joaopedrosdmm commented 1 year ago

Wow! You have done a fantastic job! And, thanks so much for the quick reply.

I will take a look at your paper, asap.

joaopedrosdmm commented 1 year ago

Hey,

I'm a little confused about how to load the models with the state_dict containing the weights retrieved from here:

from huggingface_hub import hf_hub_download
import torch

# Download only the first shard (expected to hold the ViT and Q-Former weights)
path = hf_hub_download(
    repo_id='Salesforce/blip2-flan-t5-xxl',
    filename='pytorch_model-00001-of-00006.bin',
    local_dir='models/Blip2NoLM',
    local_dir_use_symlinks=False
)

# Load on CPU, drop all language model weights, and save the rest
state_dict = torch.load(path, map_location="cpu")
state_dict = {k: v for k, v in state_dict.items() if "language_model" not in k}

torch.save(state_dict, "blip2-flant5xxl-nolm.bin")

Wondering if you could provide me with some guidance.

I looked at your code in src/modules/modeling/mblip.py. It seems that you create Blip2VisionModel and Blip2QFormerModel from transformers and give them the respective config. Is that right? Where and how do you load the state_dict into the models?

Thanks in advance. I'm super interested in this. I'm trying to understand your code.

gregor-ge commented 1 year ago

In the experiment configs, there is a variable blip_checkpoint. This has to point to your blip2-flant5xxl-nolm.bin file.

This state_dict is then loaded into the ViT and Q-Former here: https://github.com/gregor-ge/mBLIP/blob/main/src/modules/modeling/mblip.py#L129 (and subsequent lines).
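For doing the same outside the configs, here is a minimal sketch of splitting the filtered state_dict into per-module pieces by key prefix. This is not the code at the linked line; the prefixes follow the Hugging Face BLIP-2 naming and are an assumption, and placeholder strings stand in for tensors.

```python
def split_by_prefix(state_dict: dict, prefixes: tuple):
    """Group keys by prefix and strip it, so each group matches a submodule's own keys."""
    parts = {prefix: {} for prefix in prefixes}
    rest = {}
    for key, value in state_dict.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                parts[prefix][key[len(prefix):]] = value
                break
        else:  # no prefix matched (e.g. a bare parameter such as query_tokens)
            rest[key] = value
    return parts, rest


sd = {
    "vision_model.embeddings.class_embedding": "t0",
    "qformer.layernorm.weight": "t1",
    "query_tokens": "t2",
}
parts, rest = split_by_prefix(sd, ("vision_model.", "qformer."))
```

With real tensors, each group could then be passed to the matching PyTorch submodule, e.g. `vision_model.load_state_dict(parts["vision_model."])`, while `rest["query_tokens"]` would be copied into the query-token parameter.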

Does that answer your question?

joaopedrosdmm commented 1 year ago

Thanks. I believe it does.

joaopedrosdmm commented 1 year ago

I read the BLIP2 paper and your mBLIP paper, and I have a few questions, if you don't mind.

1) Why did you use a pre-trained BLIP2 with Flan-T5-XXL (encoder-decoder, instruction-tuned) instead of OPT (decoder-only, unsupervised training)? I know the benchmarks show Flan-T5 outperforming OPT on every measure, but that is expected, since one is instruction-tuned and the other was trained in an unsupervised manner, so its results would naturally be significantly lower. For text generation, GPT-style (decoder-only) models perform much better than encoder-decoder architectures. So why not use OPT?

Maybe I just didn't understand this paragraph in the blip2 paper:

We experiment with two types of LLMs: decoder-based LLMs and encoder-decoder-based LLMs. For decoder based LLMs, we pre-train with the language modeling loss, where the frozen LLM is tasked to generate the text conditioned on the visual representation from Q-Former. For encoder-decoder-based LLMs, we pre-train with the prefix language modeling loss, where we split a text into two parts. The prefix text is concatenated with the visual representation as input to the LLM’s encoder. The suffix text is used as the generation target for the LLM’s decoder.
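The two objectives in that paragraph can be illustrated on token lists. A minimal sketch: placeholder strings stand in for the Q-Former outputs, the split point is passed in explicitly (the paragraph does not say how it is chosen), `None` marks positions without a loss, and the function names are hypothetical.

```python
def decoder_lm_example(visual, text):
    """Decoder-only LLM: [visual; text] is the input, and the loss covers all text tokens."""
    inputs = visual + text
    targets = [None] * len(visual) + text  # None: no loss on visual positions
    return inputs, targets


def prefix_lm_example(visual, text, split):
    """Encoder-decoder LLM: [visual; prefix] goes to the encoder, the suffix is the decoder target."""
    encoder_inputs = visual + text[:split]
    decoder_targets = text[split:]
    return encoder_inputs, decoder_targets


visual = ["<q0>", "<q1>"]  # stand-ins for the Q-Former outputs
text = ["a", "cat", "on", "a", "sofa"]
enc, dec = prefix_lm_example(visual, text, split=2)
print(enc)  # ['<q0>', '<q1>', 'a', 'cat']
print(dec)  # ['on', 'a', 'sofa']
```

In both cases the loss is the usual next-token cross-entropy on the target tokens; only the place where the text prefix enters the model differs.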

I don't understand what they did for the Flan-T5 (encoder-decoder) training and how they calculated the loss.

Why wouldn't it be better to use a pre-trained model with an LLM similar to the one we want to use (decoder-only), perhaps even with a similar loss strategy?

2) Why didn't you use 4-bit quantization for your LLM? I know you used 8-bit and, by default, bfloat16 (with some LLMs) or FP16.

In the BLIP2 paper they mention that they compared FP32 with FP16 and saw no performance degradation. There are a lot of papers stating that the performance degradation from 4-bit quantization is quite small for many LLMs.

So, why didn't you use 4-bit quantization?

3) What is your opinion of this paper: https://arxiv.org/pdf/2308.04152.pdf

The repo is this: https://github.com/DCDmllm/Cheetah

You have a wonderful project here. Thanks so much for your contributions! I hope I don't bore you with my questions.

gregor-ge commented 1 year ago

I am happy to help with your questions.

Why did you use a pre-trained Blip2 with Flant5-XXL

I think you misunderstood mBLIP here: we used the Flan-T5-xl BLIP2 checkpoint for initialization but only for the ViT and Q-Former parts, we do not use the LLM (so Flan-T5); instead we use mT0 as the LLM part. There might be slight performance differences between choosing the different checkpoints for initialization, but in early experiments, we did not see big changes, so we simply stuck with the Flan-T5 checkpoint.

I don't understand what they have done regarding the flant5 (encoder/decoder) training and how they calculated the loss

Both pure decoders and encoder-decoders are trained the same way, with language modeling: the decoder predicts the next token and the loss is the cross-entropy between the token logits and the target tokens. When you use a prefix, as we do with the instructions, it is important for encoder-decoders (but also for pure decoders) to mask those prefix tokens for the loss.
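As a dependency-free sketch of that loss, the function below computes the mean token-level cross-entropy from raw logits and skips positions whose label is -100, which is how masked prefix tokens drop out. The -100 value matches PyTorch's default `ignore_index` for `CrossEntropyLoss`; the function name and toy numbers are illustrative.

```python
import math


def masked_cross_entropy(logits, labels, ignore_index=-100):
    """Mean of -log softmax(logits)[label] over positions whose label != ignore_index."""
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # masked (e.g. instruction/prefix) token: no loss
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[label]
        count += 1
    return total / count if count else 0.0  # 0.0 when all positions are masked (sketch choice)


# Position 0 is a prefix token and masked; position 1 has target class 1
loss = masked_cross_entropy([[5.0, -3.0], [0.0, 0.0]], [-100, 1])
print(round(loss, 4))  # 0.6931, i.e. log(2)
```

Whatever the logits at a masked position are, they do not change the loss, so the model is never pushed to reproduce the instruction.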

Why didn't you use 4-bit quantization on your LLM

4-bit quantization is slower, so if 8-bit is enough, you should use it. We do have an update in the works, though: we have trained an mBLIP model with BLOOMZ-7B, and there we had to use 4-bit to make it fit. So nothing fundamentally stops you from using 4-bit.
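To see why the bit width decides whether a model fits, a back-of-the-envelope calculation for the LLM weights alone helps. This sketch assumes a round 7 billion parameters and ignores activations, the ViT and Q-Former, LoRA and optimizer state, and the overhead of quantization scales, so real memory usage is higher.

```python
def weight_memory_gb(num_params: float, bits: int) -> float:
    """Memory for the weights alone, in GB (10**9 bytes)."""
    return num_params * bits / 8 / 1e9


for bits in (16, 8, 4):
    print(f"{bits:>2}-bit: {weight_memory_gb(7e9, bits):.1f} GB")
# 16-bit: 14.0 GB
#  8-bit: 7.0 GB
#  4-bit: 3.5 GB
```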

What your opinion of this paper: https://arxiv.org/pdf/2308.04152.pdf

I have not seen that paper yet, so thanks for pointing it out. Models like this and the recent OpenFlamingo, which can handle interleaved input, are interesting because they enable few-shot prompting or new tasks that single-image models can't do.

joaopedrosdmm commented 1 year ago

Thank you so much for your answer!

I think you misunderstood mBLIP here: we used the Flan-T5-xl BLIP2 checkpoint for initialization but only for the ViT and Q-Former parts, we do not use the LLM (so Flan-T5); instead we use mT0 as the LLM part. There might be slight performance differences between choosing the different checkpoints for initialization, but in early experiments, we did not see big changes, so we simply stuck with the Flan-T5 checkpoint.

I know you don't actually use Flan-T5, but right after that you respond spot-on to the idea behind my question.

When you do take a look at that paper, let me know your thoughts.

What do you mean by 4-bit quantization being slower? At inference time?

Btw, when will you release the new updates (roughly)?

gregor-ge commented 1 year ago

Btw when will you release (more or less) the new updates?

Probably end of the month.

joaopedrosdmm commented 1 year ago

Hey again,

I don't know if you have looked at the Qwen-VL repo/paper, but in their repo there is an intriguing response to an issue.

It states their performance gains are due to the following:

  1. Higher resolution is important for more detailed information extraction, especially for content with dense text and small fonts.
  2. There are many factors for us to achieve better performance at smaller resolutions, such as the basic LLM model, training method, training data, etc.
  3. We use a query quantity of 256, which can retain more information than 32. I also try a longer length such as 1024, which will cause a certain degree of convergence difficulty

I read the BLIP2 paper, where they state that they chose 32 query tokens because this "filters" out irrelevant information, and here is a statement saying the complete opposite.

Have you considered changing query tokens?

Could I change query tokens in mBlip and still get relevant information without additional training?

How would one go about training a model similar to mBlip with 256 (or wtv) query tokens? What would I need to change?

Additionally, how would we change the resolution for BLIP2? Would we need a major retraining?

Again, thanks so much for your time!

Edit: Btw, why didn't you start with a pre-trained InstructBLIP, even though it is instruction-tuned?

joaopedrosdmm commented 1 year ago

I just read a bit of their paper. About the resolution, this is what they did (on Multi-task Pre-training):

We increase the input resolution of the visual encoder from 224 × 224 to 448 × 448, reducing the information loss caused by image down-sampling. We unlocked the large language model and trained the whole model.

This is very weird: how can a single-layer cross-attention perform better than a Q-Former?

Their datasets aren't that unique either. Maybe it is the amount of data the model is trained on.

gregor-ge commented 1 year ago

Hi,

before I get to your questions, I want to point out that Qwen likely uses many A100s with 40 or 80 GB VRAM, while we use 3090s with 24 GB VRAM. That really helps with higher resolutions or longer sequence lengths.

Have you considered changing query tokens? Could I change query tokens in mBlip and still get relevant information without additional training? How would one go about training a model similar to mBlip with 256 (or wtv) query tokens? What would I need to change?

We have not tried that, but I would not be surprised if more tokens help. The sequence length in the LLM was one limiting factor for us (capped at 128 for training), so adding more Q-Former tokens was not considered.

You can have more query tokens by simply changing the corresponding config value, but you have to train them because they will be randomly initialized at the start.

As a side note, the ViT at 448 resolution gives 1024 patch tokens, so 256 query tokens keep 25% of them (4x compression); BLIP2 at 224 resolution gives 256 patch tokens, so 32 query tokens keep 12.5% (8x compression). So it makes sense that they have to increase the number of query tokens a lot. The "comparable" number for BLIP2 would be 64 query tokens.
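These numbers can be reproduced with a few lines, assuming a ViT patch size of 14 (which is what yields 256 patches at 224 px and 1024 at 448 px) and ignoring the extra [CLS] token:

```python
def num_patches(resolution: int, patch_size: int = 14) -> int:
    """Number of patch tokens a square image produces in a ViT."""
    return (resolution // patch_size) ** 2


def query_ratio(num_query_tokens: int, resolution: int, patch_size: int = 14) -> float:
    """Query tokens as a fraction of the ViT's patch tokens."""
    return num_query_tokens / num_patches(resolution, patch_size)


print(num_patches(224), num_patches(448))  # 256 1024
print(query_ratio(32, 224))   # 0.125 (BLIP2)
print(query_ratio(256, 448))  # 0.25  (Qwen-VL)
print(query_ratio(64, 224))   # 0.25  ("comparable" BLIP2 setting)
```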

Additionally, how would we change the resolution for BLIP2? Would we need a major retraining?

You can have a higher image resolution by interpolating the position embeddings of the ViT (https://github.com/salesforce/LAVIS/blob/main/lavis/models/eva_vit.py#L428). If you do this, you also need to train the ViT, or it will give you bad results. We tried this briefly, but it made training a lot more expensive and we only got minor improvements (because we do not have many OCR tasks, which profit most from the higher resolution).
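To make the interpolation step concrete, here is a dependency-free sketch that resizes a square grid of position embeddings while keeping the leading [CLS] embedding unchanged. Note that it uses bilinear interpolation with aligned corners, whereas the linked LAVIS function uses bicubic interpolation via PyTorch, so the numbers will differ slightly; the function name and the list-of-vectors layout are assumptions for illustration.

```python
import math


def interpolate_pos_embed(pos_embed, old_grid, new_grid, num_extra_tokens=1):
    """Bilinearly resize ViT position embeddings from old_grid**2 to new_grid**2 patches.

    pos_embed: list of vectors; the first num_extra_tokens (e.g. [CLS]) are kept as-is,
    the rest form a row-major old_grid x old_grid grid.
    """
    extra = pos_embed[:num_extra_tokens]
    patches = pos_embed[num_extra_tokens:]
    if len(patches) != old_grid * old_grid:
        raise ValueError("pos_embed does not match old_grid")
    dim = len(patches[0])
    # Map new grid indices onto old grid coordinates (corners aligned)
    scale = (old_grid - 1) / (new_grid - 1) if new_grid > 1 else 0.0

    def cell(r, c):
        return patches[r * old_grid + c]

    resized = []
    for i in range(new_grid):
        y = i * scale
        y0 = min(int(math.floor(y)), old_grid - 1)
        y1 = min(y0 + 1, old_grid - 1)
        wy = y - y0
        for j in range(new_grid):
            x = j * scale
            x0 = min(int(math.floor(x)), old_grid - 1)
            x1 = min(x0 + 1, old_grid - 1)
            wx = x - x0
            resized.append([
                (1 - wy) * ((1 - wx) * cell(y0, x0)[d] + wx * cell(y0, x1)[d])
                + wy * ((1 - wx) * cell(y1, x0)[d] + wx * cell(y1, x1)[d])
                for d in range(dim)
            ])
    return extra + resized


# 2x2 grid -> 3x3 grid with 1-dim embeddings, plus a [CLS] embedding (9.0)
out = interpolate_pos_embed([[9.0], [0.0], [2.0], [4.0], [6.0]], old_grid=2, new_grid=3)
print([v[0] for v in out])  # [9.0, 0.0, 1.0, 2.0, 2.0, 3.0, 4.0, 4.0, 5.0, 6.0]
```

The interpolated embeddings are only a starting point, which is why the ViT still has to be trained afterwards.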

Edit: Btw, why didn't you start with a pre-trained instructBlip? Even though it is instruct-tuned.

InstructBLIP inputs the instruction text into the Q-Former, which does not work with multilingual input. We discuss this in Section 3.1 of our paper.

This is very weird, how can a single-layer cross-attention perform better than a q-former? Their datasets aren't that unique either. Maybe it is the amount of data on which the model is trained.

There are too many confounding factors between Qwen and BLIP2 to say that the single-layer CA beats the q-former: x10 more training data, bigger ViT, training the ViT along with the rest, increasing the image resolution, etc. I assume that in a fair and controlled comparison, the Q-former will work better.

joaopedrosdmm commented 1 year ago

Thanks so much for your reply. Very insightful!

joaopedrosdmm commented 1 year ago

Hey,

Sorry to bother you again, but I'm a little bit lost in your code.

So in src/modules/modeling/mblip.py you re-create BLIP2 with your new LLM. Where exactly is it called for training? I guess it is in train.py, but since I've never used Lightning or Hydra, I'm a bit confused. How do you call the forward function there? Does it apply any image transformation or tokenizer beforehand?

Is this folder, src/tasks/vllm, just used for benchmark evaluation? Do I need to use any of the classes in src/tasks/vllm/data.py, like LoadTransformImage, before training?

What I'm doing is something like this:

import requests
from PIL import Image
from IPython.display import display

# `processor` and `model` are created beforehand (not shown)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
display(image.resize((254, 254)))

prompt = "Question: how many cats are there? Answer:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda")

# Generate and decode the token ids into text
generated_ids = model.generate(**inputs)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))

# Plain forward pass on the same inputs (no labels are passed)
outputs = model(**inputs)
print(outputs)

I got rubbish results, which is expected since I haven't trained the model yet. But I'm wondering whether I'm doing something wrong or not understanding your approach correctly; if you could shed some light, it would be much appreciated.

gregor-ge commented 1 year ago

Hi,

those questions are totally understandable. The disadvantage of high-level frameworks is this "magical" instantiation of everything that makes understanding difficult.

What trident does, in (very) short, is instantiate everything that Lightning needs for the training (model, dataset, ...) with the parameters in the configs. If you see any "_target_: ..." in the configs, this tells trident (or rather Hydra) to create this Python object. For example, in the experiment configs, we have module.model._target_: src.modules.modeling.mblip.mBLIPModule -> Hydra will instantiate this object with the parameters given there and then "give" this object to Lightning as the model it should train/test. The forward function will be called by Lightning during its training/testing loops.
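To demystify `_target_`, here is a heavily simplified, stdlib-only stand-in for what Hydra's `hydra.utils.instantiate` does with such a config: import the dotted path, build nested configs first, and call the target with the remaining keys as keyword arguments. The real Hydra function supports much more (positional arguments, partial instantiation, recursion flags), so treat this only as a mental model; standard-library classes are used as targets so the example runs anywhere.

```python
import importlib


def instantiate(cfg):
    """Recursively build objects from dicts that carry a "_target_" dotted path."""
    if isinstance(cfg, list):
        return [instantiate(item) for item in cfg]
    if not isinstance(cfg, dict):
        return cfg  # plain value: pass through unchanged
    kwargs = {key: instantiate(value) for key, value in cfg.items() if key != "_target_"}
    if "_target_" not in cfg:
        return kwargs
    module_path, _, attr_name = cfg["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    return target(**kwargs)


# `_target_: src.modules.modeling.mblip.mBLIPModule` would be resolved the same way;
# here standard-library classes stand in for the model and its arguments.
config = {
    "delay": {"_target_": "datetime.timedelta", "days": 1, "hours": 2},
    "ratio": {"_target_": "fractions.Fraction", "numerator": 1, "denominator": 3},
}
objects = instantiate(config)
print(objects["delay"], objects["ratio"])  # 1 day, 2:00:00 1/3
```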

For the tokenization/image transformation: if you look into the configs/datamodule configs, you can see how we use our functions in src/tasks/data.py. We create datasets with the HuggingFace datasets library and use our tokenization mapping function to preprocess the text. The LoadTransformImage object is configured to be called when we iterate over the dataset; it loads the image from the given path and applies the transform. Finally, the dataloader has a collate function that does some final data processing before the batch can be used as input for the forward/generate function of the model.
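The collate step typically pads the variable-length examples of a batch to a common length. A minimal sketch of such a function (hypothetical, not mBLIP's actual collate function; the field names follow the usual Hugging Face tokenizer output and the pad id 0 is an assumption): input ids are padded with the pad id, attention masks with 0, and labels with -100 so padding never contributes to the loss. When padding is deferred to this step, the per-example attention masks produced by tokenization would simply be all ones.

```python
def collate(batch, pad_id=0, label_pad_id=-100):
    """Right-pad input_ids/attention_mask/labels of a list of examples to common lengths."""
    max_in = max(len(ex["input_ids"]) for ex in batch)
    max_lab = max(len(ex["labels"]) for ex in batch)  # labels may differ in length (encoder-decoder)
    out = {"input_ids": [], "attention_mask": [], "labels": []}
    for ex in batch:
        n_in = max_in - len(ex["input_ids"])
        out["input_ids"].append(ex["input_ids"] + [pad_id] * n_in)
        out["attention_mask"].append(ex["attention_mask"] + [0] * n_in)  # 0 = ignore padding
        out["labels"].append(ex["labels"] + [label_pad_id] * (max_lab - len(ex["labels"])))
    return out


batch = [
    {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "labels": [-100, 6, 7]},
    {"input_ids": [8], "attention_mask": [1], "labels": [8]},
]
padded = collate(batch)
```

In a real pipeline the lists would then be converted to tensors (e.g. with `torch.tensor`) and the transformed images stacked alongside them.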

However, if you write your own code for inference (like in the example code), you do not need to use this - the code you have works just fine.

joaopedrosdmm commented 1 year ago

Thanks for the answers.

I have a few follow-up questions if you don't mind.

1)

What is the target2str for? I guess it is for converting the target into a string, but what sort of input does it require? Why don't you just do this:

targets = [self.template.format(x) for x in examples[self.target_column]]

Instead of this:

targets = [x if self.target2str is None or not x else self.target2str[x] for x in examples[self.target_column]]

2)

In the tokenization function, I don't understand why you do this:

        if self.decoder_only and all(target for target in targets):  # have to mask the context for the loss
            for i in range(batch_size):
                sample_input_ids = model_inputs["input_ids"][i]
                label_input_ids = labels["input_ids"][i] + [self.tokenizer.eos_token_id]
                model_inputs["input_ids"][i] = sample_input_ids + label_input_ids
                labels["input_ids"][i] = [-100] * len(sample_input_ids) + label_input_ids
                model_inputs["attention_mask"][i] = [1] * len(model_inputs["input_ids"][i])
        model_inputs["labels"] = labels["input_ids"]

The comment suggests it is to mask the context for decoder-only LLMs. Is it to join the label and context input ids to get the attention masks? Why add [-100]? This must be padding for the context input ids, but I can't understand why. Does it have anything to do with a specific way of training BLIP?

My tokenization looks something like this:

Why are all the attention masks ones?

{'input_ids': [[1, 14350, 263, 3273, 322, 1871, 1230, 6139, 393, 12141, 29879, 278, 7601, 17800, 322, 8820, 13920, 292, 297, 278, 2183, 1967, 29889, 1, 18776, 411, 286, 18813, 280, 6568, 292, 373, 17441, 310, 1250, 19852, 29892, 2038, 5962, 11565, 29889, 2], [1, 14350, 263, 3273, 322, 1871, 1230, 6139, 393, 12141, 29879, 278, 7601, 17800, 322, 8820, 13920, 292, 297, 278, 2183, 1967, 29889, 1, 319, 11203, 411, 372, 29915, 29879, 5076, 5764, 29892, 1791, 292, 372, 29915, 29879, 2343, 373, 278, 2625, 310, 263, 11565, 310, 4094, 29889, 2], [1, 20355, 915, 278, 1967, 297, 263, 2821, 322, 3022, 895, 8214, 29889, 1, 319, 2022, 6568, 292, 373, 2246, 310, 263, 7245, 29887, 482, 1559, 21299, 29889, 2], [1, 6204, 263, 3022, 895, 5777, 683, 393, 7913, 2486, 16612, 278, 1667, 3161, 297, 278, 1967, 4944, 29889, 1, 739, 338, 408, 1781, 263, 2058, 304, 9709, 408, 916, 10161, 310, 278, 4799, 637, 29889, 2]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 18776, 411, 286, 18813, 280, 6568, 292, 373, 17441, 310, 1250, 19852, 29892, 2038, 5962, 11565, 29889, 2], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 319, 11203, 411, 372, 29915, 29879, 5076, 5764, 29892, 1791, 292, 372, 29915, 29879, 2343, 373, 278, 2625, 310, 263, 11565, 310, 4094, 29889, 2], [-100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, 1, 319, 2022, 6568, 292, 373, 2246, 310, 263, 7245, 29887, 482, 1559, 21299, 29889, 2], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 739, 338, 408, 1781, 263, 2058, 304, 9709, 408, 916, 10161, 310, 278, 4799, 637, 29889, 2]]}
gregor-ge commented 1 year ago

What is the target2str for?

I wanted an easy way to re-map class labels in tasks like XVNLI (entailment, neutral, contradiction) to other labels like yes/no/maybe or true/false/maybe to see what works best without having to re-create the data files.

The input is a dictionary whose keys are the labels as they appear in the data file and whose values are the strings to use in their place.
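As an illustration of that dictionary, using the XVNLI labels and the yes/no/maybe strings mentioned above (the exact mapping is a hypothetical choice), the list comprehension from the tokenization code behaves like this:

```python
def map_targets(targets, target2str=None):
    """Same logic as the comprehension in the tokenization code: re-map non-empty labels."""
    return [x if target2str is None or not x else target2str[x] for x in targets]


# Hypothetical mapping from dataset labels to the strings the model should produce
target2str = {"entailment": "yes", "neutral": "maybe", "contradiction": "no"}
print(map_targets(["entailment", "contradiction", ""], target2str))  # ['yes', 'no', '']
print(map_targets(["entailment"]))  # ['entailment'] (no mapping configured)
```

Empty targets pass through unchanged because of the `not x` check, and without a mapping the labels are used verbatim.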

The comment suggests it is to mask the context for decoder-only LLMs

I guess "mask" here is a bit overloaded: this part does not mask the tokens for the attention but "masks" them for the language modeling loss in the labels: -100 is the label value that tells PyTorch to ignore this token in the loss computation (it is the default ignore_index of its cross-entropy loss). The reason to mask those tokens is that they are part of the input, so the model does not have to learn to produce them.
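To see which tokens the -100 labels actually remove from the loss, here is a sketch of the usual Hugging Face causal-LM convention, in which the model shifts the labels internally so that the prediction at position t is scored against labels[t + 1]. The function name and toy ids are illustrative; they mimic the layout of the tokenized example above (prompt, then a target that starts with BOS 1 and ends with EOS 2).

```python
def scored_predictions(labels, ignore_index=-100):
    """List (position, target_id) pairs that contribute to a causal-LM loss.

    The logits at position t predict the token at t + 1, so the last position has no
    target, and any target equal to ignore_index is skipped.
    """
    return [(t, labels[t + 1]) for t in range(len(labels) - 1) if labels[t + 1] != ignore_index]


prompt_ids = [1, 14350, 263]     # instruction tokens (masked in the labels)
target_ids = [1, 18776, 411, 2]  # answer tokens, ending with EOS
input_ids = prompt_ids + target_ids  # what the model sees; labels have the same length
labels = [-100] * len(prompt_ids) + target_ids
print(scored_predictions(labels))  # [(2, 1), (3, 18776), (4, 411), (5, 2)]
```

Only the answer tokens (including EOS) are learned; the last prompt position still matters because it predicts the first answer token.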