HaozheZhao / MIC

MMICL, a state-of-the-art VLM with in-context learning ability, from PKU

How to finetune MMICL on my own datasets #18

Closed · OuYangg closed this issue 1 year ago

OuYangg commented 1 year ago

Hi Haozhe, MIC is nice work!

I've tested MMICL-flan-t5-xxl on my own datasets and the test results are very promising. Therefore, I want to go one step further and finetune MMICL, but I ran into several problems during finetuning.

Questions: I followed the logic in data_preprocess.py to convert the Flickr dataset (train.jsonl, test.jsonl, val.jsonl) in MIC_full to .arrow files, but training reported an error during forward propagation. I suspect the problem is caused by my data preprocessing.

[screenshot of the error traceback]

  1. Is there a problem with my data preprocessing?
  2. How should I use the MIC_sample data for training?

data demo:

{'input_text': 'Carefully analyze image 0: <image0>图 to generate a concise and accurate description that accurately represents the objects, people, and scenery present. A woman resting her bags.\n\nYour caption should provide sufficient information about image 1: <image1>图 so that someone who has not seen the image can understand it. The people are sitting on the ground with white tents surrounding them.\n\nBased on the image 2, give a caption about this image. Your caption should provide enough detail about image 2: <image2>图 to give the viewer a sense of what is happening in the image. A man with a blue shirt and red scarf on his head holds a piece of fruit in front of a wood container full of fruit while a boy bicycles behind him.\n\nimage 3 is <image3>图. Be specific and detailed in your description of image 3, but also try to capture the essence of the image 3 in a succinct way. A youth wearing a blue and red jersey and yellow helmet is crouching in a football position\n\nimage 4 is <image4>图. Based on the image 4, describe what is contained in this photo. Your caption should be no more than a few sentences and should be grammatically correct and free of spelling errors. Two teenagers are standing on the roof of a white building looking down.\n\nimage 5 is <image5>图. Be specific and detailed in your description of image 5, but also try to capture the essence of the image 5 in a succinct way. A singer sings in the evening at a nightclub.\n\nGive a caption about this image. Avoid using overly complex language or jargon in your caption of image 6: <image6>图 that might confuse the viewer.',
 'input_image': ['./data/flickr/flickr30k-images/4725675686.jpg',
  './data/flickr/flickr30k-images/3479245321.jpg',
  './data/flickr/flickr30k-images/3860415767.jpg',
  './data/flickr/flickr30k-images/3495455782.jpg',
  './data/flickr/flickr30k-images/2705103507.jpg',
  './data/flickr/flickr30k-images/243015377.jpg',
  './data/flickr/flickr30k-images/1000092795.jpg'],
 'output_text': 'Two young guys with shaggy hair look at their hands while hanging out in the yard.',
 'output_image': ''}

data preprocess script:

import os
import h5py
import argparse
import threading
import numpy as np
from tqdm import tqdm
from PIL import Image
from os.path import join
from datasets import load_dataset
from datasets.arrow_writer import ArrowWriter
from model.instructblip import InstructBlipProcessor

processor = InstructBlipProcessor.from_pretrained('/workspace/data/hub/Salesforce:instructblip-flan-t5-xxl/')
sp = ['图']+[f"<image{i}>" for i in range(20)]
sp = sp+processor.tokenizer.additional_special_tokens[len(sp):]
processor.tokenizer.add_special_tokens({'additional_special_tokens':sp})
max_seq_length = min(512, processor.tokenizer.model_max_length)
def read_image(postfix, img_root, img_path):
    img_path = img_path.split('/')[-1]
    if postfix == 'png':
        image = Image.open(join(img_root,img_path))
    elif postfix == 'h5':
        image = h5py.File(join(img_root,img_path), 'r')
    else:
        image = Image.open(join(img_root, img_path))
    return image

def preprocess_function(input_text,input_image,output_text, processor, img_root):
    result = {}
    flag = isinstance(input_image,list)
    result["pixel_values"] = []
    if flag:
        postfix = input_image[0][1:].split('.')[-1]
        for img_path in input_image:
            img_path = img_path[1:] if img_path[0] == '.' and img_path[1] !='/' else img_path
            img = read_image(postfix, img_root, img_path)
            result["pixel_values"].append(processor(images = img)["pixel_values"][0])
    return result

def preprocess_function_batched(result,input_text,output_text, processor, max_seq_length):
    re = processor.tokenizer(input_text, padding='max_length', max_length=max_seq_length, truncation=True)
    re['input_ids'] = np.array(re['input_ids'],dtype=np.int32)
    re['attention_mask'] = np.array(re['attention_mask'],dtype=np.bool_)
    # result['label'] = np.array(processor.tokenizer(output_text, padding='max_length', max_length=32, truncation=True)["input_ids"],dtype=np.int32)
    out = processor.tokenizer(output_text, padding='max_length', max_length=128, truncation=True)
    result['label'] = np.array(out["input_ids"],dtype=np.int32)
    result['label_attention_mask'] = np.array(out["attention_mask"],dtype=np.bool_)
    result.update(re)
    return result

def save_to_arrow(path,temp):
    with ArrowWriter(path=path) as writer: 
        writer.write_batch(temp) 
        writer.finalize() 

def main():

    parser = argparse.ArgumentParser(add_help=True)
    parser.add_argument('--json_file', type=str)
    parser.add_argument('--save_path', type=str)
    parser.add_argument('--img_root', type=str)
    parser.add_argument('--dataset_name', type=str)
    args = parser.parse_args()
    os.makedirs(args.save_path, exist_ok=True)
    ds = load_dataset('json', data_files=args.json_file, streaming=True, split='train')
    save_arrow_data={'pixel_values':[], 'label':[], 'input_ids':[], 'attention_mask':[]}
    threads = []
    index_arrow = 0
    for idx, each in tqdm(enumerate(ds)):
        input_text = each['input_text']
        input_imgs = each['input_image']
        output_text = each['output_text']

        try:
            temp = preprocess_function(input_text,input_imgs,output_text, processor=processor, img_root=args.img_root)
        except Exception as e:
            print(e)
            continue
        if len(temp['pixel_values'])!=len(input_imgs):
            print(len(temp['pixel_values']), len(input_imgs))
        temp = preprocess_function_batched(temp, input_text, output_text, processor=processor, max_seq_length=max_seq_length)
        for key in save_arrow_data:
            save_arrow_data[key].append(temp[key])
        if idx %1000 == 0 and idx !=0:
            path = join(args.save_path, '%s_%d.arrow' % (args.dataset_name, index_arrow))
            t = threading.Thread(target=save_to_arrow, args=(path, save_arrow_data))
            threads.append(t)
            t.start()
            save_arrow_data={'pixel_values':[], 'label':[], 'input_ids':[], 'attention_mask':[]}
            index_arrow+=1
    for t in threads:
        t.join()
    path = join(args.save_path, '%s_%d.arrow' % (args.dataset_name, index_arrow))
    save_to_arrow(path,save_arrow_data)

if __name__ == '__main__':
    main()
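
For reference, a quick way to sanity-check one of the generated shards (a sketch; the path is hypothetical, and Dataset.from_file loads a single .arrow file):

from datasets import Dataset

# Load one shard produced above and check that the number of stored image
# features for a sample matches the number of image placeholders in its text.
ds_check = Dataset.from_file('./flickr_arrow/flickr_0.arrow')  # hypothetical path
print(ds_check)
print(len(ds_check[0]['pixel_values']), 'images in the first sample')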

train script:

export EXPERIMENT_NAME=instruct_BLIP_deepSpeed_t5xxl_unfreeze_Projection_LLM_QV_weight_without_instruct_qformer
export DATASET_NAME=flickr
export CUDA_VISIBLE_DEVICES=0
export MODEL_DIR=/result_dir/model/
model_name_or_path=/workspace/data/hub/MMICL-Instructblip-T5-xxl
processor_path=/workspace/data/hub/Salesforce:instructblip-flan-t5-xxl

bs=1
eval_bs=1
lr=1e-4
dropout=0.1
epoch=1
seed=1234
do_train=True
do_test=False
do_valid=False
master_port=29504
model_type=instructblip
deepspeed --master_port $master_port run.py \
--experiment_name ${EXPERIMENT_NAME} \
--dataset_name ${DATASET_NAME} \
--dataset_config_name None \
--load_datatype json \
--max_seq_length 512 \
--overwrite_cache True \
--pad_to_max_length True \
--train_file /workspace/data/mmlm_datasets/flickr30k/flickr_train \
--validation_file /workspace/data/mmlm_datasets/flickr30k/flickr_test \
--test_file /workspace/data/mmlm_datasets/flickr30k/flickr_val \
--do_train $do_train \
--do_eval $do_valid \
--do_predict $do_test \
--per_device_train_batch_size ${bs} \
--bf16 False \
--model_type $model_type \
--save_total_limit 3 \
--per_device_eval_batch_size ${eval_bs} \
--gradient_accumulation_steps 6 \
--num_train_epochs ${epoch} \
--output_dir checkpoints/${EXPERIMENT_NAME} \
--overwrite_output_dir \
--learning_rate ${lr} \
--weight_decay 0.0005 \
--seed ${seed} \
--warmup_ratio 0.2 \
--evaluation_strategy steps \
--eval_steps 100 \
--remove_unused_columns False \
--model_name_or_path $model_name_or_path \
--use_fast_tokenizer True \
--processor_path $processor_path \
--model_revision main \
--eval_type val \
--generation_max_length 64 \
--done_preprocess True \
--max_eval_samples 3500 \
--max_predict_samples 3500 \
--run_name ${EXPERIMENT_NAME} \
--using_instruct_qformer False \
--deepspeed config/deepspeed_config.json \
--load_best_model_at_end

Best wishes.

HaozheZhao commented 1 year ago

Well, I suppose you didn't modify the replace token. To be specific, the replace token in the Flan-based model is "图". Each image is encoded into a 32-token visual prompt, so you need to replace the single word "图" with a 32-token replace token: "".join(["图"] * 32).
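
A minimal sketch of that expansion (a hypothetical helper; the repo's own preprocessing may do this differently):

# Expand each single "图" placeholder into 32 copies, matching the 32 query
# tokens that the Q-Former produces for every image.
def expand_image_tokens(input_text: str, n_query_tokens: int = 32) -> str:
    replace_token = "".join(["图"] * n_query_tokens)
    return input_text.replace("图", replace_token)

This assumes "图" appears in the text only as the image placeholder, as in the data demo above.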

OuYangg commented 1 year ago

Well, I suppose you didn't modify the replace token. To be specific, the replace token in the Flan-based model is "图". Each image is encoded into a 32-token visual prompt, so you need to replace the single word "图" with a 32-token replace token: "".join(["图"] * 32).

Replacing '图' with the 32-token replace token "".join(["图"] * 32) solves part of my problem, but some of the MIC_sample data still suffers from this error. I compared normal and error samples and didn't find a significant difference: the images can be decoded and the input_text looks normal.

I used the following debug code to print shapes, and it seems that the number of special tokens in input_ids is smaller than language_model_inputs.shape[0] * language_model_inputs.shape[1].

debug code:

import torch

def generate(inputs, model):
    # Keep only the real images according to the image mask.
    pixel_values = inputs['pixel_values'][inputs['img_mask'].bool()]
    vision_feat = model.vision_model(pixel_values=pixel_values,
                                     output_attentions=None,
                                     output_hidden_states=None,
                                     return_dict=None)[0]
    print('image_embeds:', vision_feat.shape)
    img_count = inputs['img_mask'].sum(1)
    image_attention_mask = torch.ones(vision_feat.size()[:-1], dtype=torch.long, device=vision_feat.device)
    print('image_attention_mask:', image_attention_mask.shape)
    query_tokens = model.query_tokens.expand(vision_feat.shape[0], -1, -1)
    print('query_tokens:', query_tokens.shape)
    query_outputs = model.qformer(
        query_embeds=query_tokens,
        encoder_hidden_states=vision_feat,
        encoder_attention_mask=image_attention_mask,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    )[0]
    print('query_output:', query_outputs.shape)
    language_model_inputs = model.language_projection(query_outputs)
    print('language_model_inputs:', language_model_inputs.shape)
    inputs_embeds = model.get_input_embeddings()(inputs['input_ids'])
    print('inputs_embeds:', inputs_embeds.shape)
    # 32100 is the id of the "图" replace token; for the scatter below to fit,
    # its count must equal num_images * 32 (one slot per query token).
    image_embeds_index = torch.where(inputs['input_ids'] == 32100)
    print(image_embeds_index[1].shape)
    inputs_embeds[image_embeds_index] = language_model_inputs.reshape(-1, language_model_inputs.shape[-1])
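
One way to make the mismatch explicit (a sketch, reusing the names from the debug function above):

# If this assertion fails, some "图" replace tokens were dropped (e.g. by
# truncation), and the scatter into inputs_embeds will be misaligned.
n_replace_tokens = int((inputs['input_ids'] == 32100).sum())
n_visual_tokens = language_model_inputs.shape[0] * language_model_inputs.shape[1]
assert n_replace_tokens == n_visual_tokens, (n_replace_tokens, n_visual_tokens)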

Example of error data:

input_text: 'Based on the image 0, give a caption about this image. Your caption should provide enough detail about image 0: 图 to give the viewer a sense of what is happening in the image. A representation of a woman holding a surfboard on a sandy beach.\n\nBe creative in your approach to captioning image 1: 图 and try to convey a unique perspective or story. A red double decker London bus on the street\n\nBased on the image 2, give a caption about this image. Your caption should provide enough detail about image 2: 图 to give the viewer a sense of what is happening in the image. A boy doing a manual on a skateboard\n\nGive a caption about this image. Avoid using overly complex language or jargon in your caption of image 3: 图 that might confuse the viewer. A large cake shaped like two animal characters\n\nimage 4 is 图. Based on the image 4, describe what is contained in this photo. Your caption should be no more than a few sentences and should be grammatically correct and free of spelling errors. The computer desk has two laptops near the monitor.\n\nBe creative in your approach to captioning image 5: 图 and try to convey a unique perspective or story. A view of individuals at a park flying kites.\n\nCarefully analyze image 6: 图 to generate a concise and accurate description that accurately represents the objects, people, and scenery present. Some folks standing up holding some remotes together.\n\nBe creative in your approach to captioning image 7: 图 and try to convey a unique perspective or story.'

Example of normal data:

input_text: 'Your caption should provide sufficient information about image 0: 图 so that someone who has not seen the image can understand it. A batter hitting the ball at a baseball game\n\nUse clear and concise language that accurately describes the content of image 1: 图. A man hitting a tennis ball with a tennis racquet.\n\nBased on the image 2, give a caption about this image. Your caption should provide enough detail about image 2: 图 to give the viewer a sense of what is happening in the image. a bedroom with a lamp and a closet\n\nBe creative in your approach to captioning image 3: 图 and try to convey a unique perspective or story. A crowd of individuals flying kites at a park.\n\nimage 4 is 图. Based on the image 4, describe what is contained in this photo. Your caption should be no more than a few sentences and should be grammatically correct and free of spelling errors.'

Thank you for the reply!

HaozheZhao commented 1 year ago

This issue arises from the input context length. Although Flan-T5 uses relative position embeddings, a max_token_length is still set for the tokenizer. Your input context clearly exceeds the maximum length of 512, so the tokenizer truncated the input, which caused the visual prompt length to mismatch the replace token length. To fix this problem, you need to ensure that the input context you send to the model does not truncate the image replace tokens.
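
As a sketch of that check (assuming the replace-token id 32100 from the debug code above and 32 query tokens per image), samples whose replace tokens would be truncated can be filtered out before training:

import numpy as np

N_QUERY_TOKENS = 32     # tokens produced by the Q-Former per image (assumed)
IMAGE_TOKEN_ID = 32100  # id of the "图" replace token (assumed)

def replace_tokens_intact(input_ids, num_images):
    # True only if every image kept its full 32-token slot after truncation.
    n_image_tokens = int(np.sum(np.asarray(input_ids) == IMAGE_TOKEN_ID))
    return n_image_tokens == num_images * N_QUERY_TOKENS

Samples that fail this check can be dropped, or rewritten with fewer in-context examples so that everything fits within 512 tokens.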

OuYangg commented 1 year ago

Got it, thanks for the answer! One more small question: after training finishes normally I only get all_results.json, best_results.json, train_results.json, and trainer_state.json, but I don't find any weight files. Looking at the code, it seems to only save the results during training. How can I save the model weights?

Training log:

[screenshot of the training log]

Have a nice weekend!

HaozheZhao commented 1 year ago

If you wish to save the checkpoint after training, we suggest referring to the Hugging Face Trainer and manually calling its save function. For instance, the Trainer's _save_checkpoint function will allow you to save the model.
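
A minimal sketch of saving the weights explicitly after training, assuming a standard Hugging Face Trainer instance named trainer and the usual training_args/processor objects (run.py may wire this differently):

# After training completes, persist the final weights plus the processor.
train_result = trainer.train()
trainer.save_model(training_args.output_dir)   # writes the model weights and config
trainer.save_state()                           # writes trainer_state.json
processor.save_pretrained(training_args.output_dir)

Alternatively, setting --save_strategy steps together with --save_steps in the launch script makes the Trainer write periodic checkpoint-* directories, provided run.py forwards these arguments to TrainingArguments (an assumption).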

OuYangg commented 1 year ago

Got it! Thank you so much!