```python
import torch

# text prompt (tokenizer, model, and device are assumed to be loaded already)
prompt = 'Why is the image funny?'
text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{prompt} ASSISTANT:"
text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1][1:], dtype=torch.long).unsqueeze(0).to(device)
```
When encoding the text, Bunny uses `-200` to represent `<image>` as the placeholder for the image. Assuming the length of `input_ids` is L, it consists of L-1 normal tokens and one `-200`.

After `prepare_inputs_labels_for_multimodal`, the `input_ids` are converted to `input_embeds` and then fed into the model. At that point the image takes up its actual length in `input_embeds`: for SigLIP-SO, an image is encoded into 729 tokens, so the length of `input_embeds` is L+728.
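
To make the expansion concrete, here is a minimal sketch of what the placeholder replacement amounts to; `expand_image_placeholder` and its arguments are illustrative names, not Bunny's actual implementation:

```python
import torch

IMAGE_TOKEN_INDEX = -200  # the image placeholder id used in the prompt above

def expand_image_placeholder(input_ids, image_features, embed_tokens):
    # input_ids: (1, L), containing exactly one -200
    # image_features: (729, hidden) for SigLIP-SO
    # embed_tokens: the language model's token embedding layer
    pos = (input_ids[0] == IMAGE_TOKEN_INDEX).nonzero().item()
    before = embed_tokens(input_ids[:, :pos])      # (1, pos, hidden)
    after = embed_tokens(input_ids[:, pos + 1:])   # (1, L - pos - 1, hidden)
    # one placeholder token becomes 729 image embeddings: L -> L + 728
    return torch.cat([before, image_features.unsqueeze(0), after], dim=1)
```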
This change of length causes a mismatch in the `attention_mask` and related tensors during generation. We add

```python
else:
    remove_prefix_length = input_ids.shape[1] - 1
    input_ids = input_ids[:, remove_prefix_length:]
```

to make inference run normally. You can try deleting this part of the code and see what happens.
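
For context, here is a simplified sketch of where that `else` branch sits inside `prepare_inputs_for_generation`; the surrounding code follows the `remove_prefix_length` pattern used in older `transformers` versions, and the exact code varies by version:

```python
if past_key_values is not None:
    # number of tokens already held in the KV cache; in Bunny this reflects
    # the expanded input_embeds (L + 728), not the original input_ids (L)
    past_length = past_key_values[0][0].shape[2]

    if input_ids.shape[1] > past_length:
        # input_ids still holds all tokens: drop the cached prefix
        remove_prefix_length = past_length
    else:
        # past_length >= input_ids.shape[1], which happens in Bunny because
        # the cache was built from the longer input_embeds; default to the
        # old behavior and keep only the final token
        remove_prefix_length = input_ids.shape[1] - 1

    input_ids = input_ids[:, remove_prefix_length:]
```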
In fact, LLaVA once required `transformers == 4.31`, and at that time you could see this part of the code in `transformers` (this commit of `transformers` deleted it). Bunny is consistent with LLaVA as of that version.
The release of LLaVA-1.6 overrides the `generate` function (see this commit), which makes LLaVA compatible with the latest version of `transformers`. Bunny does not override `generate`, so it needs to modify the `prepare_inputs_for_generation` function instead.
Thank you very much!
In the `transformers` code, `prepare_inputs_for_generation` does not contain this fallback. However, in `bunny/model/language_model/llama/modeling_llama.py`, it does: you add the `else` branch quoted above.
When I debug the Bunny code, I find that `input_ids.shape[1]` is not the number of generated tokens, which differs from LLaVA's code. From what I understand, Bunny's code framework is largely the same as LLaVA's. So why is there such a big difference in `input_ids` during inference between your code and LLaVA's that you need to change the `transformers` source code?
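
For concreteness, a small worked example of the shapes involved, using the numbers from the explanation above (L is the prompt length including the `-200` placeholder):

```python
L = 50                                 # input_ids.shape[1] at the prefill step
image_tokens = 729                     # SigLIP-SO image length
past_length = (L - 1) + image_tokens   # KV cache built from input_embeds: L + 728

# At the first decode step, generate() appends one token and passes the
# full sequence to prepare_inputs_for_generation:
input_ids_len = L + 1
assert input_ids_len <= past_length        # so the `else` fallback applies:
remove_prefix_length = input_ids_len - 1   # keep only the newest token
```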