gregor-ge / mBLIP

MIT License

Training the bloomz model #14

Open bexxnaz opened 2 weeks ago

bexxnaz commented 2 weeks ago

Thank you again for your excellent work. I have trained an mT0 model on my own dataset, and it performs well. Now I am trying to train the Bloomz model, but the training loss does not decrease at all, with both the quantized and the non-quantized versions. Could you please help me address this issue?

gregor-ge commented 2 weeks ago

Hi,

I am happy to hear that my project is helpful for you.

Regarding your problem with Bloomz, I unfortunately have no good idea what the cause might be. The fact that training works fine with mT0 suggests to me that your setup is correct, so Bloomz should work too.

There are some config changes you have to make when switching from mT0, such as setting `datamodule.dataloader_cfg.collate_fn.padding_side="left"` and `module.model.random_init_projection=True`. However, the first only really matters during evaluation, and without the second the training should crash right away anyway.
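Why left padding matters for Bloomz but not for mT0: Bloomz is a decoder-only model, so generation continues from the last position of each row in the batch, whereas mT0 is an encoder-decoder model whose decoder starts from its own start token. The pure-Python sketch below illustrates this with made-up token ids. The helper `pad_batch` and the pad id `0` are hypothetical and are not mBLIP's actual collate function.

```python
from typing import List, Tuple

PAD_ID = 0  # hypothetical pad token id, for illustration only


def pad_batch(
    seqs: List[List[int]], side: str, pad_id: int = PAD_ID
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad token-id sequences to equal length; return (input_ids, attention_mask)."""
    if side not in ("left", "right"):
        raise ValueError(f"side must be 'left' or 'right', got {side!r}")
    max_len = max(len(s) for s in seqs)
    ids, mask = [], []
    for s in seqs:
        pad = [pad_id] * (max_len - len(s))
        ones, zeros = [1] * len(s), [0] * len(pad)
        if side == "left":
            ids.append(pad + s)
            mask.append(zeros + ones)
        else:
            ids.append(s + pad)
            mask.append(ones + zeros)
    return ids, mask


batch = [[5, 6, 7], [8]]
left_ids, _ = pad_batch(batch, "left")
right_ids, _ = pad_batch(batch, "right")
# A decoder-only LM appends new tokens after the last position of each row.
print([row[-1] for row in left_ids])   # [7, 8]: every row ends in a real prompt token
print([row[-1] for row in right_ids])  # [7, 0]: the short row ends in padding
```

With right padding, the shorter prompt would be continued from a pad token, which is one way batched evaluation can produce garbage while batch size 1 looks fine.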

Loss not going down can be caused by wrong or bad hyperparameters (learning rate, weight decay). The values in the example config should work, so I do not expect this to be the cause of your problem, but you could try increasing the learning rate and see if the loss goes down.
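When comparing runs with different learning rates, single-step losses are noisy, so it can help to compare averaged windows instead. The sketch below is a hypothetical helper (`loss_is_decreasing` is not part of mBLIP); the window size and the 5% threshold are arbitrary assumptions to tune for your logging frequency.

```python
from statistics import mean
from typing import Sequence


def loss_is_decreasing(
    losses: Sequence[float], window: int = 50, min_rel_drop: float = 0.05
) -> bool:
    """Return True if the mean of the last `window` losses is at least
    `min_rel_drop` (relative) below the mean of the first `window` losses."""
    if window <= 0:
        raise ValueError("window must be positive")
    if len(losses) < 2 * window:
        raise ValueError("need at least 2 * window logged losses")
    early = mean(losses[:window])
    late = mean(losses[-window:])
    return late <= early * (1.0 - min_rel_drop)


flat = [2.0] * 200                               # a run that is stuck
falling = [2.0 - 0.005 * i for i in range(200)]  # a run that is learning
print(loss_is_decreasing(flat))     # False
print(loss_is_decreasing(falling))  # True
```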

Sorry that I cannot be of more help.

bexxnaz commented 1 week ago

Hello, thank you very much for your quick and kind response. I trained the Bloomz model on my dataset using the int4-quantized version, and the training loss decreased as expected. However, when I test the model with the trained weights, the output is garbage, regardless of whether I set `load_8bit=True` or `load_8bit="4bit"`. What could be the problem? Thank you very much for your assistance.

```python
from PIL import Image
from transformers import AutoProcessor

# Assumed import path, based on the repo file src/modules/modeling/mblip.py
from src.modules.modeling.mblip import mBLIP

processor = AutoProcessor.from_pretrained("Gregor/mblip-bloomz-7b")
model = mBLIP(
    lm_pretrained="Gregor/mblip-bloomz-7b",
    random_init_projection=True,
    use_lora=True,
    # Note: these two paths come from different run directories (_w_llm vs. _wo_llm)
    train_checkpoint='/home/mBLIP/result_bloomz_w_llm/checkpoints/0-33960.ckpt',
    lora_checkpoint='/home/mBLIP/result_bloomz_wo_llm/checkpoints/0-33960',
)
model = model.to('cuda')

raw_image = Image.open('/home/OCO_train2014_000000231773.jpg').convert('RGB')
question = "Summarize the image in Persian"
inputs = processor(raw_image, question, return_tensors="pt").to("cuda")

outputs = model.generate(
    **inputs,
    do_sample=True,
    max_length=128,
    min_length=1,
    num_beams=5,
    top_p=0.9,
)
print(processor.decode(outputs[0], skip_special_tokens=True))
```

gregor-ge commented 1 week ago

I assume not all weights were loaded (correctly):

1) Don't forget to set blip_pretrained_checkpoint during evaluation. The weights for the image encoder are not loaded anywhere else otherwise.

2) This line (https://github.com/gregor-ge/mBLIP/blob/main/src/modules/modeling/mblip.py#L294) prints unexpected weights from train_checkpoint. If all is correct, this will just print []. Is this the case in your log?
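Both checks above can be scripted. In PyTorch, `nn.Module.load_state_dict(state_dict, strict=False)` returns a named tuple with `missing_keys` and `unexpected_keys`; the sketch below reproduces the unexpected-key part with plain sets so it runs without torch, and adds a check for `blip_pretrained_checkpoint`. Both helper names are hypothetical, and whether mblip.py computes its printout exactly this way is an assumption.

```python
from typing import Iterable, List, Mapping


def unexpected_keys(model_keys: Iterable[str], ckpt_keys: Iterable[str]) -> List[str]:
    """Checkpoint keys with no matching model parameter (sorted).

    [] means every key in the checkpoint found a parameter to load into.
    """
    return sorted(set(ckpt_keys) - set(model_keys))


def eval_setup_problems(eval_kwargs: Mapping[str, object]) -> List[str]:
    """Hypothetical pre-flight check on the keyword arguments used to build the model."""
    problems = []
    if not eval_kwargs.get("blip_pretrained_checkpoint"):
        problems.append(
            "blip_pretrained_checkpoint is not set: image encoder weights will not be loaded"
        )
    return problems


print(unexpected_keys(["vision.w", "proj.w"], ["proj.w"]))  # []
print(unexpected_keys(["proj.w"], ["proj.w", "lm.w"]))      # ['lm.w']
print(eval_setup_problems({}))
```

In real use, `model_keys` would come from `model.state_dict().keys()`; if the `.ckpt` file is a PyTorch Lightning checkpoint, its weights sit under the `"state_dict"` key of the loaded dict, possibly with an extra name prefix.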

bexxnaz commented 1 week ago

Yes, I set the `blip_pretrained_checkpoint` parameter in my code, and the log also shows []. I trained the mT0 model on my own dataset and tested it, and it works well. I want to understand how the 4-bit quantization affects the model, given that the training loss is decreasing significantly.

gregor-ge commented 1 week ago

Sorry, but then I am out of ideas. The 4bit or 8bit quantization should not affect training or evaluation - it never did for me.

If training works (loss goes down) but evaluation fails, then in my experience either 1) you forgot left-padding during evaluation (if you use batches > 1), or 2) there is some bug in the training code or a mismatch in model configuration between training and evaluation. But this is very open-ended and can be anything really.
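One way to look for a training/evaluation mismatch is to flatten both configurations into dotted keys and list every key that differs. The sketch below is a hypothetical helper (`config_diff` is not part of mBLIP); the example key names are taken from this thread, but the values are made up for illustration.

```python
from typing import Any, Dict, List, Tuple

MISSING = "<missing>"  # placeholder for a key present on only one side


def flatten(cfg: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts into dotted keys, e.g. {'a': {'b': 1}} -> {'a.b': 1}."""
    out: Dict[str, Any] = {}
    for key, value in cfg.items():
        dotted = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            out.update(flatten(value, dotted))
        else:
            out[dotted] = value
    return out


def config_diff(train_cfg: Dict[str, Any], eval_cfg: Dict[str, Any]) -> List[Tuple[str, Any, Any]]:
    """Return sorted (key, train_value, eval_value) for every differing or one-sided key."""
    a, b = flatten(train_cfg), flatten(eval_cfg)
    diffs = []
    for key in sorted(set(a) | set(b)):
        va, vb = a.get(key, MISSING), b.get(key, MISSING)
        if va != vb:
            diffs.append((key, va, vb))
    return diffs


train = {"module": {"model": {"random_init_projection": True}},
         "datamodule": {"dataloader_cfg": {"collate_fn": {"padding_side": "left"}}}}
evaluation = {"module": {"model": {"random_init_projection": True}},
              "datamodule": {"dataloader_cfg": {"collate_fn": {"padding_side": "right"}}}}
print(config_diff(train, evaluation))
# [('datamodule.dataloader_cfg.collate_fn.padding_side', 'left', 'right')]
```

The `++key=value` overrides in this thread are Hydra syntax, so the configs are presumably OmegaConf objects; `OmegaConf.to_container(cfg, resolve=True)` converts such an object into the plain dicts this sketch expects.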

bexxnaz commented 1 week ago

For training, I used `++datamodule.dataloader_cfg.collate_fn.padding_side="left"`. For evaluation, I used the model's processor: `processor = AutoProcessor.from_pretrained("Gregor/mblip-bloomz-7b")`.

gregor-ge commented 1 week ago

That should be correct then. There has to be a bug somewhere, maybe caused by updates to the libraries used, such as peft or transformers, or by something else. But it is still very strange that mT0 works fine while Bloomz causes trouble.

bexxnaz commented 1 week ago

One problem was the tokenizer I used, 'bigscience/bloomz-7b1', which is different from the tokenizer in the AutoProcessor. However, even with this change, the model is still not working well.
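A tokenizer mismatch like this can be checked directly by comparing the two vocabularies. The sketch below is a hypothetical pure-Python helper, `vocab_mismatches`, that takes two token-to-id mappings; with Hugging Face tokenizers, `tokenizer.get_vocab()` returns such a mapping. The commented lines assume the AutoProcessor exposes its tokenizer as `processor.tokenizer`, which is typical for transformers processors but not confirmed in this thread.

```python
from typing import Dict, List, Tuple


def vocab_mismatches(
    vocab_a: Dict[str, int], vocab_b: Dict[str, int]
) -> Tuple[List[str], List[str], List[str]]:
    """Compare two token->id maps.

    Returns (only_in_a, only_in_b, different_id), each sorted.
    """
    only_a = sorted(set(vocab_a) - set(vocab_b))
    only_b = sorted(set(vocab_b) - set(vocab_a))
    different_id = sorted(t for t in set(vocab_a) & set(vocab_b) if vocab_a[t] != vocab_b[t])
    return only_a, only_b, different_id


# Real use (requires downloading both tokenizers):
# from transformers import AutoProcessor, AutoTokenizer
# a = AutoTokenizer.from_pretrained("bigscience/bloomz-7b1").get_vocab()
# b = AutoProcessor.from_pretrained("Gregor/mblip-bloomz-7b").tokenizer.get_vocab()
# print(vocab_mismatches(a, b))

a = {"<pad>": 3, "hello": 10, "world": 11}   # toy vocabularies
b = {"<pad>": 0, "hello": 10, "bonjour": 12}
print(vocab_mismatches(a, b))  # (['world'], ['bonjour'], ['<pad>'])
```

Identical vocabularies do not rule out other differences, such as special tokens or the tokenizer's `padding_side`, so this is only a first check.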