NJU-LHRS / LHRS-Bot

VGI-Enhanced multimodal large language model for remote sensing images.
Apache License 2.0

Input size of image Embedding and Token Embedding #16

Closed YongTaeIn closed 1 month ago

YongTaeIn commented 1 month ago

Hello~

Looking at the paper and Fig. 3, I understand that the image embeddings and the token embeddings are concatenated and fed as input to LLaMA2.

I have a question here. In the rgb_visioni_modal.py file, I understand that EMBEDDING_DIM is set to 768 or 1024 depending on whether the ViT is base or large.

However, when I printed the input size of the LLaMA model, I found that it was 4096. If so, shouldn't the token embedding dimension be 4096 - 768 = 3328 (when vit_base), not 4096? Is there perhaps something I missed in the code? Or am I misunderstanding something?

pUmpKin-Co commented 1 month ago

Hi~

The hidden dimension of the vision encoder is projected through the bridge layer (vision perceiver) and aligned with the dimension of the underlying LLM (i.e., 4096). Have a look here for details.
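
For intuition, here is a minimal sketch of that projection, assuming a simple linear mapping. The class name and shapes are illustrative only; the repo's actual bridge is a perceiver-style module, not this code.

```python
import torch
import torch.nn as nn

# Hypothetical sketch: names and shapes are illustrative, not LHRS-Bot's real classes.
class VisionBridge(nn.Module):
    """Map vision-encoder features (768 for ViT-base, 1024 for ViT-large) to the LLM hidden size (4096)."""

    def __init__(self, vision_dim: int = 768, llm_dim: int = 4096):
        super().__init__()
        self.proj = nn.Linear(vision_dim, llm_dim)

    def forward(self, vision_feats: torch.Tensor) -> torch.Tensor:
        # vision_feats: [B, L, vision_dim] -> [B, L, llm_dim]
        return self.proj(vision_feats)


bridge = VisionBridge()
img_tokens = bridge(torch.randn(2, 196, 768))
print(img_tokens.shape)  # torch.Size([2, 196, 4096])
```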

YongTaeIn commented 1 month ago

Thanks for reply~.

Hmm... what I don't understand is how images and text are both fed into the LLM during training.

I know that LLaMA's input is 4096-dimensional, but if the image tokens are mapped to 4096 through the bridge layer (vision perceiver), isn't there no room left for the text tokens in the LLaMA input?

(As far as I know, during training the images and text are tokenized, then concatenated and fed as input to LLaMA.)

pUmpKin-Co commented 1 month ago

Usually, we concatenate along the sequence dimension.

For example, suppose you obtain vision tokens with shape [B, L, D'], where B, L, D refer to batch size, sequence length, and dimension, and text tokens with shape [B, L', D].

We first map D' -> D through the bridge layer, then concatenate the vision tokens and the text tokens along the sequence dimension, giving a result of shape [B, (L + L'), D].
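
In code, the two steps look roughly like this. This is a hedged PyTorch sketch: the concrete numbers and the nn.Linear stand-in for the vision perceiver are assumptions for illustration, not the repo's implementation.

```python
import torch
import torch.nn as nn

B, L, L_text = 2, 196, 32       # batch size, #vision tokens, #text tokens (illustrative)
D_vis, D = 768, 4096            # ViT-base hidden size, LLaMA-2 hidden size

bridge = nn.Linear(D_vis, D)    # stand-in for the bridge layer / vision perceiver

vision_tokens = bridge(torch.randn(B, L, D_vis))  # [B, L, D'] -> [B, L, D]
text_tokens = torch.randn(B, L_text, D)           # text already embedded at the LLM dimension

# Concatenate along the sequence dimension (dim=1), not the feature dimension.
llm_input = torch.cat([vision_tokens, text_tokens], dim=1)
print(llm_input.shape)  # torch.Size([2, 228, 4096])
```

The key point is that the feature dimension D stays fixed at the LLM's hidden size; only the sequence length grows when image tokens are prepended to the text tokens.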

YongTaeIn commented 1 month ago

Thank you for the detailed description.

So then (L + L') is the same as the number of input tokens to LLaMA2, right?

pUmpKin-Co commented 1 month ago

Exactly.