Great point - I think it was this one: https://github.com/huggingface/transformers/blob/881e966aced6f0f208f43d7b7e7e55bc680f8fa5/src/transformers/models/mistral/modeling_mistral.py? If you share the diff against this one and against a few commits before/after, I can tell you exactly which one.
The changes should only be related to a few additional kwargs for bidirectional attention & allowing labels to be passed through
Thanks @Muennighoff for getting back so quickly.
Following is the diff between modeling_mistral.py from transformers (commit 881e966, which you pointed to above) and scripts/modeling_mistral_gritlm.py:
$ diff mistral_hf_881e966.py modeling_mistral_gritlm.py
22a23
> import os
34c35
< from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
---
> from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
205,206c206,207
< f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
< "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
---
> f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
> "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
217d217
< self.is_causal = True
344a345
> is_causal: bool = True,
461a463
> is_causal=is_causal,
481a484
> is_causal=True,
505c508
< causal = self.is_causal
---
> causal = is_causal
508c511
< causal = self.is_causal and query_length != 1
---
> causal = is_causal and query_length != 1
631a635
> is_causal: bool = True,
645a650
> is_causal=is_causal,
692c697
< is_causal=self.is_causal and attention_mask is None and q_len > 1,
---
> is_causal=is_causal and attention_mask is None and q_len > 1,
728a734
> is_causal: Optional[bool] = True,
760a767
> is_causal=is_causal,
939a947,949
> labels: Optional[torch.LongTensor] = None,
> instruction_lens=None,
> is_causal: Optional[bool] = True,
996d1005
< # 2d mask is passed through the layers
1001,1006c1010,1020
< attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
< attention_mask,
< (batch_size, seq_length),
< inputs_embeds,
< past_key_values_length,
< )
---
> if is_causal:
> attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
> attention_mask,
> (batch_size, seq_length),
> inputs_embeds,
> past_key_values_length,
> )
> else:
> attention_mask = _prepare_4d_attention_mask_for_sdpa(
> attention_mask, inputs_embeds.dtype
> )
1009,1015c1023,1036
< attention_mask = _prepare_4d_causal_attention_mask(
< attention_mask,
< (batch_size, seq_length),
< inputs_embeds,
< past_key_values_length,
< sliding_window=self.config.sliding_window,
< )
---
> if is_causal:
> # Causal mask with -3.3895e+38 where no attention should be
> attention_mask = _prepare_4d_causal_attention_mask(
> attention_mask,
> (batch_size, seq_length),
> inputs_embeds,
> past_key_values_length,
> sliding_window=self.config.sliding_window,
> )
> else:
> # Shape: batch_size, 1, query_length, key_value_length
> attention_mask = _prepare_4d_attention_mask(
> attention_mask, inputs_embeds.dtype
> )
1036a1058
> is_causal,
1045a1068
> is_causal=is_causal,
1135,1136c1158,1159
< >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
< >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
---
> >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
> >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1163a1187
> labels=labels,
1171c1195
< if labels is not None:
---
> if (labels is not None) and (input_ids.shape[1] > 1):
1176a1201,1203
> # For deterministic loss w/ gradacc:
> #loss_fct = CrossEntropyLoss(reduction="none")
> loss_fct = CrossEntropyLoss(reduction="sum")
1181a1209,1216
> # For deterministic loss w/ gradacc:
> #loss = loss_fct(shift_logits, shift_labels).sum() / input_ids.shape[0]
> # Problem with below is
> # e.g. if we have 30 tokens, now we split them in two batches with 20 & 10
> # Then we get the losses 60 and 40 and average them
> # We get (3 + 4)/2 = 3.5
> # Meanwhile if we did it in one we would be doing 100 / 30 = 3.333
> loss = loss_fct(shift_logits, shift_labels) / attention_mask.sum()
1210c1245
< # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
---
> # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
1247a1283
> "labels": kwargs.get("labels"),
Notice that there are changes related to the loss (loss_fct = CrossEntropyLoss(reduction="sum")) and to labels ("labels": kwargs.get("labels")) which seem to be unrelated to is_causal. Are these changes needed/made specifically for gritlm?
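For context, the arithmetic in the diff comment about deterministic loss under gradient accumulation can be checked in a few lines; the token counts and summed losses below are just the toy numbers from that comment:

# Toy numbers from the diff comment: 30 tokens split into micro-batches
# of 20 and 10 tokens, with summed cross-entropy losses 60 and 40.
sum_losses = [60.0, 40.0]
token_counts = [20, 10]

# Averaging the per-micro-batch mean losses (what reduction="mean" plus
# naive gradient accumulation effectively does):
mean_of_means = sum(s / n for s, n in zip(sum_losses, token_counts)) / len(sum_losses)
print(mean_of_means)  # (3.0 + 4.0) / 2 = 3.5

# Token-weighted mean over the whole batch (what one large batch would give):
global_mean = sum(sum_losses) / sum(token_counts)
print(global_mean)  # 100 / 30 ≈ 3.333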
Looking at the typo-related differences, I think it was https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/mistral/modeling_mistral.py then.
The label changes & the instruction_lens kwarg were done for the ablations of mixing bidirectional & causal attention. We masked out parts based on the labels, but removed the code later as none of them helped (I can still share it if helpful, but it is not very clean).
The loss changes were also Grit-related, but they are irrelevant now as that code is no longer used: labels is popped here (i.e. it will be None) https://github.com/ContextualAI/gritlm/blob/08ad444616a494a6265d07afc0d077fb68a9f2a7/gritlm/training/model.py#L188 and the loss is then computed via this function https://github.com/ContextualAI/gritlm/blob/08ad444616a494a6265d07afc0d077fb68a9f2a7/gritlm/training/model.py#L187
Thanks a lot @Muennighoff, really appreciate your help. I am trying to modify llama/mistral to have bidirectional attention, and the above pointers will definitely help me.
@spookyQubit Are you done with your modification for bidirectional attention with llama? I would really appreciate it if you could share your codebase so I can double-check mine against it. https://github.com/ContextualAI/gritlm/issues/19#issuecomment-2134298261
Hi @louieworth, with help from @Muennighoff, I was able to get it to work. However, the code required so many changes that I am not confident it is free of bugs. Instead of (almost) rewriting the transformers source code, one can do something simpler, as follows:
from peft import PeftModel
from torch import nn
from transformers import LlamaModel, LlamaConfig, LlamaPreTrainedModel
from transformers.models.llama.modeling_llama import (
    LlamaRMSNorm,
    LlamaDecoderLayer,
    LlamaMLP,
    LlamaFlashAttention2,
    LlamaForCausalLM,
)


class ModifiedLlamaFlashAttention2(LlamaFlashAttention2):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Disable the causal mask so every token can attend to every other token.
        self.is_causal = False


MODIFIED_LLAMA_ATTENTION_CLASSES = {
    "flash_attention_2": ModifiedLlamaFlashAttention2,
}


class ModifiedLlamaDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        # Rebuild the decoder layer so that self_attn uses the non-causal
        # attention class above instead of the stock causal one.
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size
        self.self_attn = MODIFIED_LLAMA_ATTENTION_CLASSES[config._attn_implementation](
            config=config,
            layer_idx=layer_idx,
        )
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)


class LlamaBidirectionalModel(LlamaModel):
    def __init__(self, config: LlamaConfig):
        LlamaPreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [
                ModifiedLlamaDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()
This is a modified version of an implementation in the LLM2Vec codebase. The above assumes that we are always using flash_attention_2. The original LLM2Vec implementation is more generic and supports other attention implementations.
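For completeness, here is a rough usage sketch of the class above. The checkpoint name is only an example, and it assumes a transformers version with the attn_implementation argument and a GPU with flash-attn installed; none of this is part of the original comment:

import torch
from transformers import AutoTokenizer

# Example checkpoint; substitute whichever Llama weights you actually use.
checkpoint = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = LlamaBidirectionalModel.from_pretrained(
    checkpoint,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

inputs = tokenizer("A sentence to embed", return_tensors="pt").to(model.device)
with torch.no_grad():
    # Every token now attends to every other token (no causal mask),
    # so e.g. mean pooling over last_hidden_state gives a sentence embedding.
    hidden_states = model(**inputs).last_hidden_state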
Hi @Muennighoff, I also noticed LLM2Vec. However, there is_causal is not an argument of the forward() function; it is set to is_causal=False at initialization. I actually want to make it an argument of forward(), as my tasks include both embedding and generation with the same model. Would you mind sending your implementation (bidirectional attention with llama) to my email: jiangli3859@gmail.com?
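This is not the GritLM implementation, but one lightweight way to get both behaviours from the snippet above without patching forward() is to flip the is_causal flag on the attention modules around each call. A rough sketch, assuming the layer layout above (model.layers[i].self_attn) and the flash_attention_2 path, where that flag alone decides causality:

from contextlib import contextmanager

@contextmanager
def attention_mode(model, causal: bool):
    # Temporarily set is_causal on every attention module, then restore it.
    # Only meaningful with flash_attention_2, where the 2D mask is passed
    # through unchanged and is_causal controls the causal masking.
    attn_modules = [layer.self_attn for layer in model.layers]
    previous = [m.is_causal for m in attn_modules]
    for m in attn_modules:
        m.is_causal = causal
    try:
        yield model
    finally:
        for m, prev in zip(attn_modules, previous):
            m.is_causal = prev

# Bidirectional pass for embeddings:
# with attention_mode(model, causal=False):
#     emb = model(**inputs).last_hidden_state
# A causal pass for generation would use causal=True (with an LM head on top).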
May I know if I can use the LlamaBidirectionalModel snippet above directly as the modeling file for llama3.1-7B, so I could train it with the GRIT pipeline? Thank you!
Hi @Muennighoff, thanks a lot for making the code public. I had a question regarding training a mistral embedding model. Specifically, you note the following in the README:

This is of great help, but it would also be very helpful to know the exact changes you had to make to the transformers modeling_mistral.py file. I tried doing a diff between the file in this repo and the file on current transformers master. However, because the file in this repo must be an edited version of an older transformers file, there are multiple differences not related to is_causal.

When your time permits, can you please point out the diff/changes you made to the modeling_mistral.py file for gritlm? Or point out the exact version of the transformers file you started from (on top of which you made your edits). Thanks.