ContextualAI / gritlm

Generative Representational Instruction Tuning
https://arxiv.org/abs/2402.09906
MIT License

attn attribute setting #34

Closed. louieworth closed this issue 1 month ago

louieworth commented 1 month ago
    attn: str = field(
        default='bbcc',
        metadata={
            "help": "bidirectional/causal attn for emb inst., emb sample, gen inst., gen sample"
                    " e.g. bbcc is bidirectional over both emb inst. & sample but causal over gen inst. & sample"
                    " cccc is causal over all; bccc is bidirectional over emb inst. but causal over rest etc."
        }
    )
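
To make sure I read the help text right, here is how I understand the four letters (an illustrative decoding only, not code from the repo):

    # Illustrative only: decode an attn string such as 'bbcc' into the four
    # attention modes listed in the help text above.
    SEGMENTS = ["emb_instruction", "emb_sample", "gen_instruction", "gen_sample"]

    def decode_attn(attn: str) -> dict:
        assert len(attn) == 4 and set(attn) <= {"b", "c"}, "expected 4 chars, each 'b' or 'c'"
        return {seg: ("bidirectional" if ch == "b" else "causal") for seg, ch in zip(SEGMENTS, attn)}

    print(decode_attn("bbcc"))
    # {'emb_instruction': 'bidirectional', 'emb_sample': 'bidirectional',
    #  'gen_instruction': 'causal', 'gen_sample': 'causal'}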

I notice that attn=cccc is set for all scenarios ['unified', 'embedding', 'generative']. Is this right for ['unified', 'embedding'] tasks, or do we need to set attn=bbcc for ['unified', 'embedding'] in the encode function:

    def encode(self, features):
        if features is None: return None
        # Clone to avoid modifying the original tensor
        attention_mask = features['attention_mask'].clone() if 'attention_mask' in features else None
        instruction_lens = features['instruction_lens'] if 'instruction_lens' in features else None
        kwargs = {'input_ids': features.get('input_ids'), 'attention_mask': attention_mask}

        if self.attn[:2] == 'cb':
            # Causal over the embedding instruction, bidirectional over the sample:
            # the model needs the instruction lengths to build the mixed mask.
            kwargs['instruction_lens'] = instruction_lens
        elif self.attn[:2] == 'bb':
            # Fully bidirectional attention over instruction & sample for embedding.
            kwargs['is_causal'] = False
        out = (getattr(self.model, self.embedding_attr) if self.embedding_attr else self.model)(**kwargs)[0]
Muennighoff commented 1 month ago

Yes, it is always used. If the first two letters are b, it will use bidirectional attention for embedding via the code you pasted; if they are c, it will use causal attention for embedding.

louieworth commented 1 month ago

Thanks! How do we determine which attention setting we should use for embedding? I think we should always use bbcc for embedding (embedding tasks require bidirectional attention). Please correct me if I am wrong.

Muennighoff commented 1 month ago

Yes, bbcc will always perform better. One small advantage of cccc is that, if you intend to use the models for RAG with GRIT as described in the paper, there is no attention mismatch between bidirectional & causal, but that is probably not worth the performance drop from using cc instead of bb for embedding.
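
For context, a rough sketch of the two choices at load time (this assumes the GritLM wrapper in this repo accepts attn as a constructor argument; treat it as a sketch, not a verbatim recipe):

    from gritlm import GritLM

    # Embedding-focused use: bidirectional attention over instruction & sample.
    emb_model = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode="embedding", attn="bbcc")

    # RAG with GRIT caching: fully causal, so there is no bidirectional/causal
    # mismatch between the embedding and generative passes, at the cost of some
    # embedding quality.
    rag_model = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode="unified", attn="cccc")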

louieworth commented 1 month ago

Sorry, I think I am a little bit confused. If my training code uses mode=unified, what is the suggested setting for attn?

Muennighoff commented 1 month ago

> Sorry, I think I am a little bit confused. If my training code uses mode=unified, what is the suggested setting for attn?

bbcc.

louieworth commented 1 month ago

Thanks! I still have three questions here:

  1. Just to confirm: the paper says there are two things (Figure 3): a. prompt design for embedding tasks, and b. bidirectional attention over the input for embedding tasks. Is there any other architectural modification besides specifying attn=bbcc for embedding tasks, i.e. for:

    out = (getattr(self.model, self.embedding_attr) if self.embedding_attr else self.model)(**kwargs)[0]
  2. However, in the training script for the unified model (GRIT), attn=cccc. Is that a typo?

  3. I tried to use LoRA in my custom training code, but I found that the following code:

    out = (getattr(self.model, self.embedding_attr) if self.embedding_attr else self.model)(**kwargs)[0]
    # embedding_attr = model
    # kwargs = dict_keys(['input_ids', 'attention_mask'])

    outputs logits rather than the last_hidden_state, so I modified it to:

    out = (getattr(self.model, embedding_attr) if embedding_attr else self.model)(**kwargs, output_hidden_states=True)
    out = out.hidden_states[-1]
    # embedding_attr = model
    # kwargs = dict_keys(['input_ids', 'attention_mask'])

    When I add is_causal=True to kwargs, it raises this error:

    Exception has occurred: TypeError
    GPTNeoXForCausalLM.forward() got an unexpected keyword argument 'is_causal'

    Will this cause any problems that make it fail on embedding tasks?
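
For now I guard the kwarg like this (my own workaround, not code from this repo); note that when is_causal is not accepted, it silently falls back to causal attention:

    import inspect

    def forward_backbone(model, kwargs, want_bidirectional=False):
        """Return the last hidden state, passing is_causal only if supported."""
        if want_bidirectional and "is_causal" in inspect.signature(model.forward).parameters:
            kwargs["is_causal"] = False
        # Models without is_causal support (e.g. GPTNeoXForCausalLM) fall back
        # to their default causal attention here.
        out = model(**kwargs, output_hidden_states=True)
        return out.hidden_states[-1]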

louieworth commented 1 month ago

I just reviewed this problem and found similar issues in #24 and #15 regarding attn=cccc. However, those seem to be specific to the Mistral model; what about other models, e.g., Pythia and Llama?

How can I enable bidirectional attention with attn=bbcc for them?

Muennighoff commented 1 month ago

For other models, you need to add is_causal to their modeling code. You can see how it is done for Mistral here: https://github.com/ContextualAI/gritlm/blob/9883da1e77812e6ba2c107dc7b65d8c5ddc7396b/scripts/modeling_mistral_gritlm.py#L949
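
Conceptually, the change just makes the causal part of the attention mask optional. A simplified, self-contained sketch (not the actual Mistral patch, which threads is_causal through the Hugging Face forward signatures):

    import torch

    def build_additive_mask(attention_mask: torch.Tensor, is_causal: bool = True) -> torch.Tensor:
        """Simplified sketch, not the repo's modeling code.

        attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding.
        Returns an additive float mask of shape (batch, 1, seq_len, seq_len).
        """
        bsz, seq_len = attention_mask.shape
        # Keys that are padding are never attended to.
        allowed = attention_mask[:, None, None, :].bool().expand(bsz, 1, seq_len, seq_len)
        if is_causal:
            # Causal: additionally block attention to future positions.
            causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
            allowed = allowed & causal[None, None]
        # Convert the boolean "allowed" mask to an additive mask for the attention scores.
        return torch.zeros(bsz, 1, seq_len, seq_len).masked_fill(~allowed, torch.finfo(torch.float32).min)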