ContextualAI / gritlm

Generative Representational Instruction Tuning
https://arxiv.org/abs/2402.09906
MIT License

Are the GRIT model training scripts compatible with other language models? #19

Closed 5joon2 closed 3 months ago

5joon2 commented 3 months ago

Nice work! Your team's work is very impressive :)

I have a question:

There are LLM architectures, like Llama or Qwen, that are very similar to Mistral but differ slightly in some of their layers.

Is this repo's codebase compatible with other pre-trained LLM architectures?

Muennighoff commented 3 months ago

Yes, this repo is compatible with other LLMs; I have tested Llama & GPT-J, and I expect Qwen to work too.
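
For reference, a minimal sketch of pointing this repo's GritLM wrapper (from the README) at a different base checkpoint; meta-llama/Llama-2-7b-hf is just an example, and a stock checkpoint that hasn't been GRIT-finetuned will still embed with plain causal attention:

```python
# Minimal sketch: load an arbitrary HF checkpoint through the GritLM wrapper.
# "meta-llama/Llama-2-7b-hf" is only an example; without GRIT finetuning
# (and a bidirectional modeling file) embeddings come from causal attention.
from gritlm import GritLM

model = GritLM("meta-llama/Llama-2-7b-hf", torch_dtype="auto", mode="embedding")
reps = model.encode(["GRIT unifies embedding and generation."], instruction="<|embed|>\n")
print(reps.shape)
```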

5joon2 commented 3 months ago

I can't find a Llama-architecture-based GRIT model class (one that uses bidirectional attention) in the grit/scripts directory. Could you share yours?

Muennighoff commented 3 months ago

I think you can just use the Mistral scripts but change MistralDecoderLayer to LlamaDecoderLayer here https://github.com/ContextualAI/gritlm/blob/f0c3820e9dde0ea2beb0c4ede775eeaac3398eda/scripts/configs/config_8gpusfsdp_m7.yml#L14 & mistralai/Mistral-7B-v0.1 to meta-llama/Llama-2-7b-hf here https://github.com/ContextualAI/gritlm/blob/f0c3820e9dde0ea2beb0c4ede775eeaac3398eda/scripts/training/train_gritlm_7b.sh#L54
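
Concretely, the config edit would look something like this (a sketch only; the field name follows accelerate's FSDP config format, and the file's other keys are omitted):

```yaml
# scripts/configs/config_8gpusfsdp_m7.yml (sketch of a Llama variant)
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer  # was MistralDecoderLayer
```

In the launch script, the base model is then swapped the same way, assuming the usual HF-style --model_name_or_path flag.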

5joon2 commented 3 months ago

Got it, thank you for the advice. Once I've done some experiments, I'll share the results 👍

phartman-keysight commented 1 month ago

@Muennighoff Is there any reason we couldn't do this for llama 3 as well?

Is it reasonable to think that the LLM that has the best performance at generating answers would also be the best at retrieving documents after it's been finetuned with the GRIT approach? Or are there other factors I'm not taking into account that play a more important role in embedding generation?

Muennighoff commented 1 month ago

Is there any reason we couldn't do this for llama 3 as well?

You can do it with Llama 3.

Is it reasonable to think that the LLM that has the best performance at generating answers would also be the best at retrieving documents after it's been finetuned with the GRIT approach?

Generally, yes. I'd expect finetuning Llama 3 with GRIT to yield a better model than the current GritLM-7B.

louieworth commented 1 month ago

Hi, I am confused about the implementation of bidirectional attention for other LLMs, such as Llama.

@5joon2 Have you implemented the transformers/models/llama/modeling_llama.py code? If so, please share it with me. I also found that LLM2Vec has implemented bidirectional attention, but I think is_causal is not an available arg for forward() in their implementation. I want is_causal to be an arg for every forward() so I can control whether to use bidirectional attention, to support both generation and embedding tasks at the same time.

Hi @Muennighoff, based on your response about adapting bidirectional attention to Llama, I am confused: don't we need to change the code in transformers/models/llama/modeling_llama.py?

Muennighoff commented 1 month ago

Yes, in order to use bidirectional attention you also need to change the model's modeling code, as mentioned here: https://github.com/ContextualAI/gritlm/issues/34#issuecomment-2123719509

But if you don't care about bidirectional attention you can use all the other models as is.
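
To illustrate what the flag controls (a toy sketch, not the repo's modeling code): threading is_causal through the attention layers just decides which additive mask they see.

```python
import torch

def additive_attention_mask(seq_len: int, is_causal: bool, dtype=torch.float32):
    # Additive-mask convention used by transformers: 0.0 where attention is
    # allowed, a large negative value where it is blocked.
    if is_causal:
        # Generation mode: token i may attend only to positions <= i.
        mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
        return torch.triu(mask, diagonal=1)
    # Embedding mode: fully bidirectional, nothing is blocked.
    return torch.zeros((seq_len, seq_len), dtype=dtype)

print(additive_attention_mask(4, is_causal=True))
print(additive_attention_mask(4, is_causal=False))
```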

louieworth commented 1 month ago

Thanks @Muennighoff. I compared your modeling_mistral_gritlm.py with the vanilla modeling_mistral.py, but I found more than 20 modifications related to is_causal and was totally lost.

I am not an expert in transformers; would you mind providing us with a modeling_llama_gritlm.py script, like modeling_mistral_gritlm.py, with the is_causal arg?

Muennighoff commented 1 month ago

See this issue: https://github.com/ContextualAI/gritlm/issues/32#issuecomment-2083769904; you need to compare with https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/mistral/modeling_mistral.py
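
One way to see exactly which lines differ (a hypothetical helper; both file paths are assumptions about where the two repos are checked out):

```python
# Diff the GritLM modeling file against the matching upstream transformers
# file to locate the is_causal-related edits. Paths below are assumptions.
import difflib

with open("gritlm/gritlm/modeling_mistral_gritlm.py") as f:
    gritlm_src = f.readlines()
with open("transformers/src/transformers/models/mistral/modeling_mistral.py") as f:
    upstream_src = f.readlines()

diff = difflib.unified_diff(
    upstream_src,
    gritlm_src,
    fromfile="modeling_mistral.py",
    tofile="modeling_mistral_gritlm.py",
)
print("".join(diff))
```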