Closed — OuYangg closed this issue 1 year ago.
Well, I suppose you didn't modify the replace token. To be specific, the replace token in the Flan-based model is "图". The image is encoded into a 32-token visual prompt, so you need to replace the single "图" with a 32-token replace token: replace_token = "".join(["图"] * 32).
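For instance, a minimal sketch of that replacement (the prompt string here is only an illustration, not taken from MIC itself):

# Build the 32-copy replace token used with the Flan-based model.
replace_token = "".join(["图"] * 32)
prompt = "image 0 is 图. Based on the image 0, give a caption about this image."
# Every single-character image placeholder becomes 32 replace tokens,
# matching the 32 visual-prompt vectors produced per image.
prompt = prompt.replace("图", replace_token)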
Replacing '图' with the 32-token replace token "".join(["图"] * 32) solves part of my problem, but some of MIC_sample's data still suffers from this issue. I compared the normal and error samples and didn't find a significant difference: the image can be decoded and the input_text looks normal.
I used the following debug code to print the shapes, and it seems that the number of replace tokens (sp_token, id 32100) in input_ids is smaller than language_model_inputs.shape[0] * language_model_inputs.shape[1], so the final assignment fails.
debug code:
import torch

def generate(inputs, model):
    # Keep only the pixel values of the images actually present in the batch.
    pixel_values = inputs['pixel_values'][inputs['img_mask'].bool()]
    # Encode the images with the vision encoder.
    vision_feat = model.vision_model(pixel_values=pixel_values,
                                     output_attentions=None,
                                     output_hidden_states=None,
                                     return_dict=None)[0]
    print('image_embeds:', vision_feat.shape)
    img_count = inputs['img_mask'].sum(1)
    image_attention_mask = torch.ones(vision_feat.size()[:-1], dtype=torch.long, device=vision_feat.device)
    print('image_attention_mask:', image_attention_mask.shape)
    # Expand the learnable query tokens to one set per image.
    query_tokens = model.query_tokens.expand(vision_feat.shape[0], -1, -1)
    print('query_tokens:', query_tokens.shape)
    # The Q-Former turns each image into a fixed number of query outputs (32 here).
    query_outputs = model.qformer(
        query_embeds=query_tokens,
        encoder_hidden_states=vision_feat,
        encoder_attention_mask=image_attention_mask,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    )[0]
    print('query_output:', query_outputs.shape)
    # Project the query outputs into the language model's embedding space.
    language_model_inputs = model.language_projection(query_outputs)
    print('language_model_inputs:', language_model_inputs.shape)
    # Embed the text tokens.
    inputs_embeds = model.get_input_embeddings()(inputs['input_ids'])
    print('inputs_embeds:', inputs_embeds.shape)
    # Positions of the replace token (id 32100, i.e. '图') in the text.
    image_embeds_index = torch.where(inputs['input_ids'] == 32100)
    print(image_embeds_index[1].shape)
    # Scatter the visual prompt vectors into the replace-token positions;
    # this fails when the number of replace tokens does not match the number of visual prompt vectors.
    inputs_embeds[image_embeds_index] = language_model_inputs.reshape(-1, language_model_inputs.shape[-1])
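As a quick sanity check (not part of the original debug code; a suggested addition just before the last assignment in generate), one could compare how many replace tokens survived tokenization against how many visual-prompt vectors were produced:

# Inside generate(), right before the scatter assignment:
n_sp_tokens = (inputs['input_ids'] == 32100).sum().item()
n_visual = language_model_inputs.shape[0] * language_model_inputs.shape[1]
print('replace tokens in input_ids:', n_sp_tokens, '| visual prompt vectors:', n_visual)
if n_sp_tokens != n_visual:
    print('mismatch: some 图 replace tokens are missing from input_ids (e.g. dropped by truncation)')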
Example of error data:
input_text: 'Based on the image 0, give a caption about this image. Your caption should provide enough detail about image 0: 图 to give the viewer a sense of what is happening in the image. A representation of a woman holding a surfboard on a sandy beach.\n\nBe creative in your approach to captioning image 1: 图 and try to convey a unique perspective or story. A red double decker London bus on the street\n\nBased on the image 2, give a caption about this image. Your caption should provide enough detail about image 2: 图 to give the viewer a sense of what is happening in the image. A boy doing a manual on a skateboard\n\nGive a caption about this image. Avoid using overly complex language or jargon in your caption of image 3: 图 that might confuse the viewer. A large cake shaped like two animal characters\n\nimage 4 is 图. Based on the image 4, describe what is contained in this photo. Your caption should be no more than a few sentences and should be grammatically correct and free of spelling errors. The computer desk has two laptops near the monitor.\n\nBe creative in your approach to captioning image 5: 图 and try to convey a unique perspective or story. A view of individuals at a park flying kites.\n\nCarefully analyze image 6: 图 to generate a concise and accurate description that accurately represents the objects, people, and scenery present. Some folks standing up holding some remotes together.\n\nBe creative in your approach to captioning image 7: 图 and try to convey a unique perspective or story.'

Example of normal data:
input_text: 'Your caption should provide sufficient information about image 0: 图 so that someone who has not seen the image can understand it. A batter hitting the ball at a baseball game\n\nUse clear and concise language that accurately describes the content of image 1: 图. A man hitting a tennis ball with a tennis racquet.\n\nBased on the image 2, give a caption about this image. Your caption should provide enough detail about image 2: 图 to give the viewer a sense of what is happening in the image. a bedroom with a lamp and a closet\n\nBe creative in your approach to captioning image 3: 图 and try to convey a unique perspective or story. A crowd of individuals flying kites at a park.\n\nimage 4 is 图. Based on the image 4, describe what is contained in this photo. Your caption should be no more than a few sentences and should be grammatically correct and free of spelling errors.'
Thank you for your reply!
This issue arises from the input context length. Although FLAN-T5 uses relative position embeddings, a max_token_length is still set for the tokenizer. Your input context exceeds the maximum length of 512, so the tokenizer truncated the input, which caused the visual prompt length to mismatch the replace-token length. To fix this problem, make sure that the input context you send to the model does not have its image replace tokens truncated.
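One way to check for this (an illustrative sketch; tokenizer stands for whatever tokenizer/processor you use, and 32100 is the replace-token id from the debug code above):

# Tokenize once without truncation and once with it, then compare replace-token counts.
ids_full = tokenizer(input_text, truncation=False, return_tensors='pt')['input_ids']
ids_trunc = tokenizer(input_text, truncation=True, max_length=512, return_tensors='pt')['input_ids']
n_full = (ids_full == 32100).sum().item()
n_trunc = (ids_trunc == 32100).sum().item()
if n_trunc < n_full:
    print(f'truncation dropped {n_full - n_trunc} replace tokens; shorten the context or raise max_length')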
Got it, thanks for the answer! One more small question: after training finishes normally I only get all_results.json, best_results.json, train_results.json, and trainer_state.json, but I don't find any weight files. I checked the code, and it seems to only save the results during training. How can I get the weights?
Training log:
Have a nice weekend!
If you wish to save the checkpoint after training, we suggest referring to the HuggingFace Trainer and manually calling its save function. For instance, the Trainer's _save_checkpoint function will allow you to save the model.
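For example, a minimal sketch assuming a standard HuggingFace Trainer setup (trainer and processor are whatever objects your training script builds; the output path is illustrative):

# After trainer.train() has finished:
trainer.save_model('output/mmicl_flan_t5_finetuned')          # writes the model weights and config
trainer.save_state()                                          # writes trainer_state.json to the output dir
processor.save_pretrained('output/mmicl_flan_t5_finetuned')   # optional: save the processor/tokenizer too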
Got it! Thank you so much!
Hi haozhe, MIC is a nice work!
I've tested MMICL-flan-t5-xxl on my own datasets and the test results are very promising, so I want to go one step further and finetune MMICL, but I ran into several problems during finetuning.
Question: I followed the logic in data_preprocess.py to convert the flickr dataset (train.jsonl, test.jsonl, val.jsonl) in MIC_full to an .arrow file, but it reported an error during forward propagation. I suspect the problem is caused by my data preprocessing.
data demo:
data preprocess script:
train script:
Best wishes.