yonxie opened this issue 3 months ago
The last hidden state is produced via bidirectional attention in the model itself
Hi, I'm currently trying to train GritLM on Gemma-2B to generate embeddings. While reviewing the training script for Mistral-7B, I noticed that it uses bidirectional attention via attn='bbcc'. For embeddings, is it better to train with 'bbcc' or with 'cccc'?
However, when I tried attn='bbcc' with Gemma, I got an error: TypeError: GemmaModel.forward() received an unexpected keyword argument 'is_causal'. To work around it, I commented out the following line in gritlm.py:
if (self.attn is not None) and (self.attn[:2] == 'bb'): inputs["is_causal"] = False
Is this correct?
'bbcc' is better, and commenting out that line makes it equivalent to 'cccc', so it's not a good idea. Also see https://github.com/ContextualAI/gritlm/issues/24
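Rather than deleting the line, a safer pattern is to fail loudly when the backbone cannot honour `is_causal`, so that 'bbcc' never silently degrades to 'cccc'. The sketch below assumes the same `attn[:2] == 'bb'` check as gritlm.py; `accepts_is_causal` and `build_embedding_inputs` are hypothetical helpers, not part of the repo. It only inspects the forward signature, so a model that accepts `is_causal` but ignores it would still get through.

```python
import inspect


def accepts_is_causal(forward_fn):
    """True if `forward_fn` declares an explicit `is_causal` parameter.

    A **kwargs catch-all is deliberately NOT counted: a model may swallow the
    keyword and ignore it, which would silently keep causal attention.
    """
    return "is_causal" in inspect.signature(forward_fn).parameters


def build_embedding_inputs(inputs, attn, forward_fn):
    """Hypothetical helper: request bidirectional attention only if supported."""
    inputs = dict(inputs)  # copy so the caller's dict is not mutated
    if attn is not None and attn[:2] == "bb":
        if not accepts_is_causal(forward_fn):
            # Failing here is preferable to training 'bbcc' that behaves like 'cccc'.
            name = getattr(forward_fn, "__qualname__", repr(forward_fn))
            raise TypeError(
                f"{name} has no 'is_causal' parameter; use a modeling file "
                "that implements bidirectional attention"
            )
        inputs["is_causal"] = False
    return inputs
```

Inside a model wrapper this might be called as `build_embedding_inputs(inputs, self.attn, self.model.forward)` (attribute names assumed). For Gemma, the fix that actually preserves 'bbcc' semantics is a modeling file that builds a non-causal mask, as scripts/modeling_mistral_gritlm.py does for Mistral.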
Hi @Muennighoff, amazing work! I have the same confusion as @yonxie. I can see here that you apply a final pooling step. You mentioned that "The last hidden state is produced via bidirectional attention in the model itself". Could you point out where this is done?
I was also looking at the query-doc caching example on page 63. To reuse the key-value cache (if I understand correctly, the key-values are produced during the forward pass using bidirectional attention), does that mean GritLM functions as a prefix LM with two independent prefixes during RAG?
Sorry for the confusion. I mean that inside the model, bidirectional attention is applied in every transformer layer. The attention mask for that is created here: https://github.com/ContextualAI/gritlm/blob/47b7fe6c7109ba46b82b68c37d32aa9a8bf010c5/scripts/modeling_mistral_gritlm.py#L1018
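To make the difference concrete, here is a toy, single-head attention with no learned projections (an assumption for brevity; the real model applies its mask in every layer). Under a causal mask the first token can only see itself and only the last token sees the whole input, whereas under a bidirectional mask every position sees every other position.

```python
import math


def attention_mask(seq_len, causal):
    """mask[i][j] is True when position i may attend to position j."""
    return [[(j <= i) if causal else True for j in range(seq_len)] for i in range(seq_len)]


def attend(q, k, v, mask):
    """Single-head scaled dot-product attention over lists of vectors."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        # Masked-out positions get -inf so their softmax weight is exactly 0.
        scores = [
            sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) if mask[i][j] else -math.inf
            for j, kj in enumerate(k)
        ]
        m = max(scores)  # finite: position i can always attend to itself
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total for c in range(len(v[0]))])
    return out


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
causal = attend(x, x, x, attention_mask(3, causal=True))
bidir = attend(x, x, x, attention_mask(3, causal=False))
```

Here `causal[0]` is just `x[0]`, while `bidir[0]` mixes in all three tokens; only the last row agrees between the two.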
The pooling that you point to is then applied to the final hidden state returned from the model to remove the sequence length dimension.
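A minimal sketch of that pooling step, assuming mean pooling over non-padding tokens (one common choice; the linked pooling code is authoritative for the method GritLM actually uses). It turns a [batch][seq_len][dim] tensor into [batch][dim]. With bidirectional attention, every averaged state has already seen the whole input.

```python
def mean_pool(hidden_states, attention_mask):
    """Average token vectors where mask == 1, removing the sequence dimension.

    hidden_states: [batch][seq_len][dim]; attention_mask: [batch][seq_len] of 0/1.
    Returns [batch][dim].
    """
    pooled = []
    for seq, mask in zip(hidden_states, attention_mask):
        n = sum(mask)
        if n == 0:
            raise ValueError("sequence has no unmasked tokens")
        dim = len(seq[0])
        # Padding positions (mask == 0) are excluded from the average.
        pooled.append([sum(tok[c] for tok, m in zip(seq, mask) if m) / n for c in range(dim)])
    return pooled
```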
if I understand correctly, the key-values are produced during the forward pass using bidirectional attention
Yes
does that mean GritLM functions as a prefix LM with two independent prefixes during RAG?
The two caches (or prefixes, if you will) are concatenated, but they have not attended to one another (maybe this is what you mean by independent). You may find this code example helpful: https://github.com/ContextualAI/gritlm?tab=readme-ov-file#caching
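The toy below sketches that idea under heavy simplifying assumptions: one layer, one head, no projections, and a prefix's "cache" stands in for the per-layer key-values as the hidden states after bidirectional self-attention within that prefix only. `encode_prefix` and `generate_step` are hypothetical names, not the README's API. The doc cache differs from what a joint encoding would produce because doc tokens never attended to the query, yet a new generation token still attends over both concatenated caches.

```python
import math


def softmax_attend(qi, keys, values):
    """One query vector attending over all given keys/values (no mask)."""
    d = len(qi)
    scores = [sum(a * b for a, b in zip(qi, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    t = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, values)) / t for c in range(len(values[0]))]


def encode_prefix(tokens):
    """Toy cache: bidirectional self-attention restricted to one prefix."""
    return [softmax_attend(t, tokens, tokens) for t in tokens]


def generate_step(new_token, *caches):
    """The new token attends over the concatenation of independently built caches."""
    combined = [vec for cache in caches for vec in cache]
    return softmax_attend(new_token, combined, combined)


query = [[1.0, 0.0]]
doc = [[0.0, 1.0], [1.0, 1.0]]
q_cache = encode_prefix(query)       # built without seeing the doc
d_cache = encode_prefix(doc)         # built without seeing the query
joint = encode_prefix(query + doc)   # what a single joint pass would produce
out = generate_step([0.5, 0.5], q_cache, d_cache)
```

The design point is reuse: because `d_cache` does not depend on the query, a document's cache can be computed once and paired with any query cache.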
You mention that bidirectional attention is used for the embedding task, but it appears that you only use the last hidden states from the pretrained LLM to generate embeddings. Is the final projection the only bidirectional part?