ContextualAI / gritlm

Generative Representational Instruction Tuning
https://arxiv.org/abs/2402.09906
MIT License

Changes needed within modeling_mistral.py for an embedding model #32

Closed by spookyQubit 2 months ago

spookyQubit commented 2 months ago

Hi @Muennighoff, thanks a lot for making the code public. I have a question about training a Mistral embedding model. Specifically, you note the following in the README:

For GritLM-7B and GritLM-8x7B, the folder contains a custom modeling file (modeling_gritlm*.py) which adds bidirectional attention via the keyword argument is_causal, such that if you load them with from_pretrained in transformers, it is automatically available.

This is of great help. But it would also be very helpful to know the exact changes you had to make to the transformers modeling_mistral.py file.

I tried doing a diff between the file in this repo and the file in current transformers master. However, because the file in this repo must be an edited version of an older transformers file, there are multiple differences not related to is_causal.

When your time permits, can you please point out the diff/changes you made to the modeling_mistral.py file for gritlm? Or point out the exact version of the transformers file which you started from (on top of which you made your edits).

Thanks.

Muennighoff commented 2 months ago

Great point - I think it was this one: https://github.com/huggingface/transformers/blob/881e966aced6f0f208f43d7b7e7e55bc680f8fa5/src/transformers/models/mistral/modeling_mistral.py? If you share the diff against this one and a few commits before / after it, I can tell you exactly which one.

The changes should only be related to a few additional kwargs for bidirectional attention & allowing labels to be passed through
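
For readers following along, the effect of such an is_causal switch on the attention mask can be illustrated in plain PyTorch. This is a minimal sketch, not the gritlm or transformers implementation: build_additive_mask is a hypothetical helper, and it ignores sliding-window attention and cached key/values.

```python
import torch


def build_additive_mask(padding_mask: torch.Tensor, is_causal: bool,
                        dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Hypothetical helper: turn a 2D padding mask into a 4D additive mask.

    padding_mask: (batch, seq_len), 1 for real tokens and 0 for padding.
    Returns (batch, 1, seq_len, seq_len) with 0 where attention is allowed
    and the most negative value of `dtype` where it is blocked.
    """
    batch, seq_len = padding_mask.shape
    # Padding keys are never attended to, whatever the direction.
    allowed = padding_mask[:, None, None, :].bool().expand(batch, 1, seq_len, seq_len)
    if is_causal:
        # Additionally block every key that lies after the query position.
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        allowed = allowed & causal
    mask = torch.zeros(batch, 1, seq_len, seq_len, dtype=dtype)
    return mask.masked_fill(~allowed, torch.finfo(dtype).min)


padding_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding
causal_mask = build_additive_mask(padding_mask, is_causal=True)
bidir_mask = build_additive_mask(padding_mask, is_causal=False)
```

With is_causal=False only the padding column stays blocked, so the first token can attend to the tokens after it.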

spookyQubit commented 2 months ago

Thanks @Muennighoff for getting back so quickly.

Following is the diff between modeling_mistral.py from transformers (commit 881e966 which you pointed above) and scripts/modeling_mistral_gritlm.py:

$ diff mistral_hf_881e966.py modeling_mistral_gritlm.py
22a23
> import os
34c35
< from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
---
> from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
205,206c206,207
<                 f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
<                 "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
---
>                 f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
>                 "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
217d217
<         self.is_causal = True
344a345
>         is_causal: bool = True,
461a463
>             is_causal=is_causal,
481a484
>         is_causal=True,
505c508
<             causal = self.is_causal
---
>             causal = is_causal
508c511
<             causal = self.is_causal and query_length != 1
---
>             causal = is_causal and query_length != 1
631a635
>         is_causal: bool = True,
645a650
>                 is_causal=is_causal,
692c697
<             is_causal=self.is_causal and attention_mask is None and q_len > 1,
---
>             is_causal=is_causal and attention_mask is None and q_len > 1,
728a734
>         is_causal: Optional[bool] = True,
760a767
>             is_causal=is_causal,
939a947,949
>         labels: Optional[torch.LongTensor] = None,
>         instruction_lens=None,
>         is_causal: Optional[bool] = True,
996d1005
<             # 2d mask is passed through the layers
1001,1006c1010,1020
<             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
<                 attention_mask,
<                 (batch_size, seq_length),
<                 inputs_embeds,
<                 past_key_values_length,
<             )
---
>             if is_causal:
>                 attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
>                     attention_mask,
>                     (batch_size, seq_length),
>                     inputs_embeds,
>                     past_key_values_length,
>                 )
>             else:
>                 attention_mask = _prepare_4d_attention_mask_for_sdpa(
>                     attention_mask, inputs_embeds.dtype
>                 )
1009,1015c1023,1036
<             attention_mask = _prepare_4d_causal_attention_mask(
<                 attention_mask,
<                 (batch_size, seq_length),
<                 inputs_embeds,
<                 past_key_values_length,
<                 sliding_window=self.config.sliding_window,
<             )
---
>             if is_causal:
>                 # Causal mask with -3.3895e+38 where no attention should be
>                 attention_mask = _prepare_4d_causal_attention_mask(
>                     attention_mask,
>                     (batch_size, seq_length),
>                     inputs_embeds,
>                     past_key_values_length,
>                     sliding_window=self.config.sliding_window,
>                 )
>             else:
>                 # Shape: batch_size, 1, query_length, key_value_length
>                 attention_mask = _prepare_4d_attention_mask(
>                     attention_mask, inputs_embeds.dtype
>                 )
1036a1058
>                     is_causal,
1045a1068
>                     is_causal=is_causal,
1135,1136c1158,1159
<         >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
<         >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
---
>         >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1163a1187
>             labels=labels,
1171c1195
<         if labels is not None:
---
>         if (labels is not None) and (input_ids.shape[1] > 1):
1176a1201,1203
>             # For deterministic loss w/ gradacc:
>             #loss_fct = CrossEntropyLoss(reduction="none")
>             loss_fct = CrossEntropyLoss(reduction="sum")            
1181a1209,1216
>             # For deterministic loss w/ gradacc:
>             #loss = loss_fct(shift_logits, shift_labels).sum() / input_ids.shape[0]
>             # Problem with below is
>             # e.g. if we have 30 tokens, now we split them in two batches with 20 & 10
>             # Then we get the losses 60 and 40 and average them
>             # We get (3 + 4)/2 = 3.5
>             # Meanwhile if we did it in one we would be doing 100 / 30 = 3.333
>             loss = loss_fct(shift_logits, shift_labels) / attention_mask.sum()
1210c1245
<             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
---
>             # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
1247a1283
>                 "labels": kwargs.get("labels"),

Notice that there are changes related to the loss (loss_fct = CrossEntropyLoss(reduction="sum")) and to labels ("labels": kwargs.get("labels")), which seem to be unrelated to is_causal. Are these changes needed/made specifically for gritlm?
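
The arithmetic in the diff's comment about gradient accumulation can be checked directly. The sketch below uses the numbers from that comment (30 tokens split into micro-batches of 20 and 10, with summed losses 60 and 40); the function names are hypothetical and only illustrate the two ways of averaging.

```python
def mean_of_batch_means(batches):
    """Average each micro-batch by its own token count, then average the results."""
    per_batch = [loss_sum / n_tokens for loss_sum, n_tokens in batches]
    return sum(per_batch) / len(per_batch)


def global_token_mean(batches):
    """Sum all losses and divide once by the total number of tokens."""
    total_loss = sum(loss_sum for loss_sum, _ in batches)
    total_tokens = sum(n_tokens for _, n_tokens in batches)
    return total_loss / total_tokens


# (summed loss, number of tokens) per micro-batch, as in the diff comment
micro_batches = [(60.0, 20), (40.0, 10)]

split = mean_of_batch_means(micro_batches)  # (3 + 4) / 2
joint = global_token_mean(micro_batches)    # 100 / 30
```

The two values differ (3.5 versus roughly 3.333), which is the mismatch the comment describes when micro-batches hold different numbers of tokens.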

Muennighoff commented 2 months ago

Looking at the typo-related differences, I think it was https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/mistral/modeling_mistral.py then.

The label changes & instruction_lens kwarg were done for the ablations of mixing bidir & causal attn. We masked out parts based on the labels, but removed the code later as none of them helped (can still share it if helpful but not very clean).

The loss changes were also Grit-related, but they're irrelevant now: that code is no longer used because labels is popped here (i.e. it will be None) https://github.com/ContextualAI/gritlm/blob/08ad444616a494a6265d07afc0d077fb68a9f2a7/gritlm/training/model.py#L188 and the loss is then computed via this func https://github.com/ContextualAI/gritlm/blob/08ad444616a494a6265d07afc0d077fb68a9f2a7/gritlm/training/model.py#L187
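
The pattern described here, where labels is removed from the inputs so the model's own loss branch is skipped and the caller computes the loss, might look roughly like the following. This is only an illustration: toy_forward, external_loss, and the batch layout are hypothetical stand-ins, not gritlm's code.

```python
def toy_forward(input_ids, labels=None):
    """Hypothetical stand-in for a forward pass that returns a loss only if labels are given."""
    logits = [float(token) for token in input_ids]
    loss = None
    if labels is not None:
        loss = sum((x - y) ** 2 for x, y in zip(logits, labels))
    return {"logits": logits, "loss": loss}


def external_loss(logits, labels):
    """Hypothetical loss computed by the training wrapper instead of the model."""
    return sum((x - y) ** 2 for x, y in zip(logits, labels)) / len(labels)


batch = {"input_ids": [1, 2, 3], "labels": [1, 2, 5]}
labels = batch.pop("labels")      # the model never receives labels
outputs = toy_forward(**batch)    # so its internal loss branch is skipped
loss = external_loss(outputs["logits"], labels)
```

Because labels never reaches the forward call, outputs["loss"] stays None and any loss logic inside the modeling file has no effect on training.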

spookyQubit commented 2 months ago

Thanks a lot @Muennighoff. Really appreciate your help. I am trying to modify Llama/Mistral to have bidirectional attention, and the above pointers will definitely help me.

louieworth commented 1 month ago

@spookyQubit Have you finished your modification for bidirectional attention with Llama? I would really appreciate it if you could share your code so that I can check it against mine. https://github.com/ContextualAI/gritlm/issues/19#issuecomment-2134298261

spookyQubit commented 1 month ago

Hi @louieworth, with help from @Muennighoff, I was able to get it to work. However, the code requires so many changes that I cannot be confident it is free of bugs. Instead of (almost) rewriting the transformers source code, one can do something simpler, as follows:

from torch import nn
from transformers import LlamaModel, LlamaConfig, LlamaPreTrainedModel
from transformers.models.llama.modeling_llama import (
    LlamaRMSNorm,
    LlamaDecoderLayer,
    LlamaMLP,
    LlamaFlashAttention2,
)

class ModifiedLlamaFlashAttention2(LlamaFlashAttention2):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Disable causal masking so every token attends to the full sequence
        self.is_causal = False

MODIFIED_LLAMA_ATTENTION_CLASSES = {
    "flash_attention_2": ModifiedLlamaFlashAttention2,
}

class ModifiedLlamaDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        # Skip LlamaDecoderLayer.__init__ so the default attention module is never built
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size

        self.self_attn = MODIFIED_LLAMA_ATTENTION_CLASSES[config._attn_implementation](
            config=config,
            layer_idx=layer_idx
        )

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

class LlamaBidirectionalModel(LlamaModel):
    def __init__(self, config: LlamaConfig):
        # Skip LlamaModel.__init__ so the modified decoder layers below are used instead
        LlamaPreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [
                ModifiedLlamaDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

This is a modified version of an implementation in the LLM2Vec codebase. The above assumes that flash_attention_2 is always used. The original LLM2Vec code is more generic and supports other attention implementations.

louieworth commented 1 month ago

Hi @Muennighoff, I also noticed LLM2Vec. However, I think is_causal there is not an argument of the forward() function; it is set to is_causal=False at initialization. I actually want to make it an argument of forward(), as my tasks include both embedding and generation with the same model. Would you mind sending your implementation (bidirectional attention with Llama) to my email, jiangli3859@gmail.com?
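
The difference between fixing is_causal at construction time and passing it per call can be shown on a toy module. This is a minimal sketch under stated assumptions, not the gritlm or LLM2Vec code: ToySelfAttention is a hypothetical single-head layer, and it relies on torch.nn.functional.scaled_dot_product_attention from PyTorch 2.0 or later. In a real Llama or Mistral model the keyword would also have to be threaded through the model and decoder-layer forward calls, as the diff earlier in this thread does.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToySelfAttention(nn.Module):
    """Hypothetical single-head self-attention whose causality is chosen per call."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=False)

    def forward(self, x: torch.Tensor, is_causal: bool = True) -> torch.Tensor:
        # Project to queries, keys, values and add a head dimension of size 1.
        q, k, v = (t.unsqueeze(1) for t in self.qkv(x).chunk(3, dim=-1))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
        return out.squeeze(1)


torch.manual_seed(0)
attn = ToySelfAttention(hidden_size=8)
x = torch.randn(1, 5, 8)  # (batch, seq_len, hidden_size)

causal_out = attn(x, is_causal=True)   # generation-style attention
bidir_out = attn(x, is_causal=False)   # embedding-style attention, same weights
```

The same weights serve both modes. The outputs agree at the last position, which sees the whole sequence either way, and differ at earlier positions.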