magic-research / PLLaVA

Official repository for the paper PLLaVA

train_pllava_13b.sh throws an error at runtime #26

Open emmating12 opened 6 months ago

emmating12 commented 6 months ago

Hello, I ran into the following problem while trying to train the PLLaVA 13B model on two 40GB A100s. What could be the cause? I have printed out the input_ids.

UserWarning: None of the inputs have requires_grad=True. Gradients will be None

ValueError: The input provided to the model are wrong. The number of image tokens is 16 while the number of image given to the model is 16. This prevents correct indexing and breaks batch generation.


RE-N-Y commented 6 months ago

In _merge_input_ids_with_image_features

Under

image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
image_to_overwrite &= image_to_overwrite.cumsum(-1) > nb_image_pad[:, None].to(target_device)

Add this line

image_to_overwrite = (image_to_overwrite.cumsum(-1) <=  num_special_image_tokens.max() * num_image_patches - nb_image_pad[:, None]) & image_to_overwrite
liuao743 commented 6 months ago

> In _merge_input_ids_with_image_features
>
> Under
>
> image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
> image_to_overwrite &= image_to_overwrite.cumsum(-1) > nb_image_pad[:, None].to(target_device)
>
> Add this line
>
> image_to_overwrite = (image_to_overwrite.cumsum(-1) <= num_special_image_tokens.max() * num_image_patches - nb_image_pad[:, None]) & image_to_overwrite

Hi, I encountered the same error. I followed your instructions to make the changes, but it still throws an error:

ValueError: The input provided to the model are wrong. The number of image tokens is 0 while the number of image given to the model is 1. This prevents correct indexing and breaks batch generation.

@RE-N-Y

RE-N-Y commented 6 months ago

So, I'm currently doing additional SFT on my own dataset and I'm not sure about inference. But to give you some insight, here's an explanation of why the error is happening.

Similar to other VLM implementations, PLLaVA first extracts image features and tokenizes the sentence. So your input ids look something like this (notice that the input has padding on the right):

[
   [ <Image>, Who, let, the, dogs, out, ?, <PAD>, <PAD>, <PAD>],
   [ <Image>, Who, let, the, cats, and, dogs, out ,?]
]
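As a hedged sketch of how such a right-padded batch arises (the checkpoint name is just an illustrative llava-style tokenizer, not PLLaVA's actual loader):

from transformers import AutoTokenizer

# Illustrative only: any llava-style tokenizer with an <image> special token
# behaves the same way here.
tok = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")
if tok.pad_token is None:
    tok.pad_token = tok.eos_token  # guard in case no pad token is set
tok.padding_side = "right"         # right padding is typical for training

batch = tok(
    ["<image> Who let the dogs out?", "<image> Who let the cats and dogs out?"],
    padding=True,
    return_tensors="pt",
)
# The shorter sequence gets tok.pad_token_id appended on the right --
# exactly the <PAD> positions shown above.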

PLLaVA then calculates the size of the final feature embedding (text tokens plus image patch features) and creates a zero tensor:

# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
    batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(
    batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
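For a concrete sense of the sizing (assuming 576 patches per image, as with a CLIP ViT-L/14 encoder at 336px; your numbers may differ):

# Toy sizing for the first row of the example batch above.
seq_len = 10              # tokenized length, including one <image> placeholder
num_images = 1            # one <image> special token in the row
num_image_patches = 576   # assumption: CLIP ViT-L/14 @ 336px
max_embed_dim = seq_len + num_images * (num_image_patches - 1)  # 10 - 1 + 576 = 585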

After that, it populates this tensor with the text embeddings:

# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
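To see what that fancy-indexing assignment does, here's a toy scatter with made-up shapes (one row, four text tokens, two slots reserved for image patches):

import torch

final_embedding = torch.zeros(1, 6, 2)
inputs_embeds = torch.arange(8, dtype=torch.float).reshape(1, 4, 2)  # 4 text tokens
batch_indices = torch.tensor([0, 0, 0, 0])
text_to_overwrite = torch.tensor([0, 3, 4, 5])    # positions 1-2 are for the image
non_image_indices = torch.tensor([0, 1, 2, 3])

final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
# Positions 1 and 2 of final_embedding stay all-zero, awaiting image features.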

But because the inputs have right padding, the embeddings at padded positions are also zero. The problem is that the code assumes at this point that:

  1. any nonzero vector = already populated with a text vector
  2. any zero vector = not populated yet = needs to be overwritten with an image vector

image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
image_to_overwrite &= image_to_overwrite.cumsum(-1) > nb_image_pad[:, None].to(target_device)
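Here's a minimal toy reproduction of the false positives (shapes and values invented for illustration):

import torch

# Row 0 mimics the shorter, right-padded sequence: its last two positions
# were never filled with text embeddings, so they are still all-zero.
final_embedding = torch.randn(2, 8, 4)
final_embedding[0, 3:6] = 0   # positions reserved for image patches
final_embedding[0, 6:] = 0    # <PAD> positions (also all-zero!)
final_embedding[1, 3:6] = 0   # positions reserved for image patches

image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
print(image_to_overwrite.sum(-1))  # tensor([5, 3]): row 0 flags 5 slots,
                                   # but only 3 image-patch features exist for it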

But because of right padding, this assumption obviously isn't true. That's why this line

image_to_overwrite = (image_to_overwrite.cumsum(-1) <=  num_special_image_tokens.max() * num_image_patches - nb_image_pad[:, None]) & image_to_overwrite

takes care of the problem: it caps the number of flagged positions at the number of available image-patch features, so the trailing pad positions are no longer mistaken for image slots. (This line actually exists in the code but is commented out, for some reason.)
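Continuing the toy example, the cap drops the spurious pad positions (shapes again invented; in the real code num_special_image_tokens and nb_image_pad come from the batch):

import torch

image_to_overwrite = torch.tensor([
    [False, False, False, True, True, True, True, True],    # row 0: 2 pads flagged
    [False, False, False, True, True, True, False, False],  # row 1: correct
])
num_special_image_tokens = torch.tensor([1, 1])  # one <image> token per row
num_image_patches = 3                            # toy value; 576 in practice
nb_image_pad = torch.tensor([0, 0])              # no image padding in this toy

image_to_overwrite = (
    image_to_overwrite.cumsum(-1)
    <= num_special_image_tokens.max() * num_image_patches - nb_image_pad[:, None]
) & image_to_overwrite
print(image_to_overwrite.sum(-1))  # tensor([3, 3]): pad slots no longer flagged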

For inference, I'm assuming the inputs might be left-padded, which the line above might not fix. My cheap personal trick would be to run inference with batch size 1 to avoid these issues altogether. (Yes, I know this is inefficient.) Otherwise, you can probably try to figure out the correct image_to_overwrite positions yourself.
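If you go the batch-size-1 route, here's a hedged sketch, using a generic llava-style checkpoint as a stand-in for however you actually load PLLaVA:

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Placeholder checkpoint: swap in PLLaVA's actual weights and loading code.
name = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(name)
model = LlavaForConditionalGeneration.from_pretrained(
    name, torch_dtype=torch.float16, device_map="auto"
)

# One sample per forward pass: no padding, so the merge logic never
# sees spurious all-zero positions.
for prompt, image in samples:  # `samples` is your own (prompt, image) pairs
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
    out = model.generate(**inputs, max_new_tokens=64)
    print(processor.decode(out[0], skip_special_tokens=True))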