ContextualAI / gritlm

Generative Representational Instruction Tuning
https://arxiv.org/abs/2402.09906
MIT License

Bidirectional attention or causal attention for embedding? #15

Open yonxie opened 3 months ago

yonxie commented 3 months ago

You mention that bidirectional attention is used for the embedding task, but it appears that you only use the last hidden states from the pretrained LLM to generate embeddings. Is the final projection the only bidirectional part?

Muennighoff commented 3 months ago

The last hidden state is produced via bidirectional attention in the model itself
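For readers wondering what "bidirectional attention in the model itself" means concretely, here is a minimal, illustrative sketch (not GritLM's actual code) of the difference between the additive causal mask used for generation and the all-zeros bidirectional mask used when producing embeddings:

```python
# Illustrative sketch only: causal vs. bidirectional attention masks for a
# toy sequence of length 4. A value of 0 means "may attend", -inf means masked.
import torch

seq_len = 4

# Causal: token i can only attend to tokens <= i (strictly upper triangle is masked).
causal_mask = torch.triu(
    torch.full((seq_len, seq_len), float("-inf")), diagonal=1
)

# Bidirectional: no positions are masked, so every token attends to every other
# token. This is what happens inside each transformer layer in embedding mode.
bidirectional_mask = torch.zeros(seq_len, seq_len)

print(causal_mask)
print(bidirectional_mask)
```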

Hisarlik commented 2 months ago

Hi, I'm currently trying to train GritLM with Gemma 2B to generate embeddings. While reviewing the training script for Mistral 7B, I noticed the use of bidirectional attention with attn='bbcc'. In the context of embeddings, would it be more advantageous to train with 'bbcc' or 'cccc'?

However, when I tried to use attn='bbcc' with Gemma, I encountered an error: TypeError: GemmaModel.forward() received an unexpected keyword argument 'is_causal'. To fix this, I commented out the following line in gritlm.py:

```python
if (self.attn is not None) and (self.attn[:2] == 'bb'): inputs["is_causal"] = False
```

Is this correct?

Muennighoff commented 2 months ago

bbcc is better, and commenting out that line makes it equivalent to cccc, so it's not a good idea. Also see https://github.com/ContextualAI/gritlm/issues/24
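For context, here is a simplified sketch of the idea behind the gritlm.py line quoted above (the helper name build_embedding_inputs is made up for illustration; the real code constructs the model inputs elsewhere):

```python
# Simplified sketch, not the full implementation. `attn` is a 4-character mode
# string: 'bbcc' = bidirectional for embedding, causal for generation;
# 'cccc' = causal everywhere.
def build_embedding_inputs(tokenized, attn):
    inputs = dict(tokenized)
    # The first two characters control the embedding pass. Only when they are
    # 'bb' is the backbone told to drop the causal mask; removing this branch
    # leaves the model causal, i.e. effectively 'cccc'.
    if attn is not None and attn[:2] == "bb":
        inputs["is_causal"] = False
    return inputs
```

Note that the backbone has to accept an is_causal keyword for this to work; the patched Mistral in scripts/modeling_mistral_gritlm.py does, while a stock GemmaModel does not, which is why the TypeError appears. The likely fix is a similar patch for Gemma rather than removing the flag.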

Vincent-Li-9701 commented 2 months ago

Hi @Muennighoff, amazing work! I have a similar confusion as @yonxie. I can see here that you do a final pooling. You mentioned that "The last hidden state is produced via bidirectional attention in the model itself". Would you mind pointing out where this is done?

I was also looking at the query-doc caching example on page 63. To reuse the key-value cache (if I understand correctly, the key-values are produced during the forward pass using bidirectional attention), does that mean GRITLM functions as a prefix LM with two independent prefixes during RAG?

Muennighoff commented 2 months ago

Sorry for the confusion. I mean that inside the model, bidirectional attention is applied in every transformer layer. The attention mask for that is created here: https://github.com/ContextualAI/gritlm/blob/47b7fe6c7109ba46b82b68c37d32aa9a8bf010c5/scripts/modeling_mistral_gritlm.py#L1018

The pooling that you point to is then applied to the final hidden state returned from the model to remove the sequence length dimension.
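As an illustration of that pooling step, here is a sketch of mean pooling, one common choice, which collapses the sequence-length dimension of the last hidden state into a single embedding per example (shapes are made up for illustration):

```python
# Sketch of mean pooling over the final hidden state, which removes the
# sequence-length dimension. Padding tokens are excluded via the attention mask.
import torch

batch, seq_len, hidden = 2, 5, 8
last_hidden_state = torch.randn(batch, seq_len, hidden)    # model output
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])           # 1 = real token

mask = attention_mask.unsqueeze(-1).float()                # (batch, seq, 1)
summed = (last_hidden_state * mask).sum(dim=1)             # (batch, hidden)
counts = mask.sum(dim=1).clamp(min=1e-9)                   # avoid divide-by-zero
embeddings = summed / counts                               # (batch, hidden)
print(embeddings.shape)  # torch.Size([2, 8])
```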

if I understand correctly, the key-values are produced during the forward pass using bidirectional attention

Yes

does that mean GRITLM functions as a prefix LM with two independent prefixes during RAG?

The two caches (or prefixes, if you will) are concatenated and have not attended to one another (maybe this is what you mean by independent). You may find it helpful to look at this code example: https://github.com/ContextualAI/gritlm?tab=readme-ov-file#caching
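To make the "two independent prefixes" point concrete, here is a conceptual sketch of merging two separately computed key-value caches along the sequence axis. It assumes the legacy per-layer (key, value) tuple layout with shape (batch, heads, seq, head_dim); the README caching example linked above shows the real usage.

```python
# Conceptual sketch of reusing two independently computed key-value caches
# during RAG. Neither prefix attended to the other when its cache was built;
# the generator attends over both once they are concatenated.
import torch

num_layers, batch, heads, head_dim = 2, 1, 4, 16
q_len, d_len = 6, 20  # query prefix length, document prefix length

def fake_cache(seq_len):
    # One (key, value) pair per layer, as a forward pass would produce.
    return [(torch.randn(batch, heads, seq_len, head_dim),
             torch.randn(batch, heads, seq_len, head_dim))
            for _ in range(num_layers)]

query_cache = fake_cache(q_len)  # computed while embedding the query
doc_cache = fake_cache(d_len)    # computed while embedding the document

# Concatenate along the sequence axis (dim=2) for every layer.
merged_cache = [
    (torch.cat([qk, dk], dim=2), torch.cat([qv, dv], dim=2))
    for (qk, qv), (dk, dv) in zip(query_cache, doc_cache)
]
print(merged_cache[0][0].shape)  # torch.Size([1, 4, 26, 16])
```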