kaistAI / LangBridge

[ACL 2024] LangBridge: Multilingual Reasoning Without Multilingual Supervision
https://aclanthology.org/2024.acl-long.405/

Different Transformer Version for Gemma2 #18

Open Kosei1227 opened 2 days ago

Kosei1227 commented 2 days ago

Hi, our team wants to test LangBridge on Gemma2, but since Gemma2 is not supported in transformers==4.37.2, I built the LangBridge model for Gemma2 against transformers==4.42.0. However, the results are terrible: the model cannot even produce coherent text.
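For context, a quick way to confirm whether the installed Transformers version ships Gemma2 at all (the import simply fails on older releases) is:

import transformers
from transformers import Gemma2Config, Gemma2ForCausalLM  # ImportError on versions without Gemma2

print(transformers.__version__)  # 4.42.0 in our setup

The prompt template we use is: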

def get_template(template_name):
    if 'gemma2' in template_name:
        print("Gemma2 template")
        # Gemma's chat format puts each turn marker on its own line,
        # so the newlines are part of the template
        return (
            "<bos><start_of_turn>user\n"
            "{user_message}<end_of_turn>\n"
            "<start_of_turn>model\n"
        )
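One way to sanity-check a hand-written template like this is to compare it against the chat template bundled with the tokenizer. This is a sketch, assuming access to an instruction-tuned checkpoint such as google/gemma-2-2b-it:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
rendered = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello"}],
    tokenize=False,
    add_generation_prompt=True,
)
# repr() makes missing or extra newlines visible
print(repr(rendered))
print(repr(get_template("gemma2").format(user_message="Hello")))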

We also defined a custom LangBridge model class, shown below.

class LBGemma2(LBBaseModel):
    config: LangBridgeConfig

    def __init__(self, config: LangBridgeConfig, random_init=True):
        # Gemma2 classes only exist in newer Transformers versions
        from transformers import Gemma2ForCausalLM, Gemma2Config, Gemma2Model
        super().__init__(config, random_init=random_init)

        if random_init:
            # Load the Gemma2 config and build a randomly initialized model.
            # NB: depending on the Transformers/flash-attn versions, flash
            # attention may skip Gemma2's attention logit soft-capping.
            model_config = Gemma2Config.from_pretrained(config.lm)
            try:
                model_config._attn_implementation = 'flash_attention_2'
                base_lm: Gemma2ForCausalLM = Gemma2ForCausalLM(config=model_config)
            except ImportError:
                print('Not using Flash Attention!')
                # Reset the backend, otherwise the retry hits the same ImportError
                model_config._attn_implementation = 'eager'
                base_lm = Gemma2ForCausalLM(config=model_config)
        else:
            # Load the pretrained weights from Hugging Face.
            # `use_flash_attention_2=True` is deprecated; `attn_implementation`
            # is the supported argument in recent versions.
            print('Loading Gemma2 model from pretrained weights')
            try:
                base_lm: Gemma2ForCausalLM = Gemma2ForCausalLM.from_pretrained(
                    config.lm, attn_implementation='flash_attention_2'
                )
            except ImportError:
                print('Not using Flash Attention!')
                base_lm = Gemma2ForCausalLM.from_pretrained(config.lm)

        # The hidden size declared in LangBridgeConfig must match the LM's
        assert self.config.dim_lm == base_lm.config.hidden_size, \
            f"specified {self.config.dim_lm=} in LangBridgeConfig, but {config.lm} has hidden size={base_lm.config.hidden_size}"

        # Expose the decoder stack, LM head, and input embeddings to LangBridge
        self.lm: Gemma2Model = base_lm.model
        self.lm_head = base_lm.lm_head
        self.embeddings = base_lm.get_input_embeddings()
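For reference, we instantiate it roughly like this. This is only a sketch: lm and dim_lm are the fields used above, the full LangBridgeConfig constructor signature follows the LangBridge codebase, and google/gemma-2-2b (hidden size 2304) is just an example checkpoint:

# Hypothetical usage; constructor arguments beyond lm/dim_lm omitted
lb_config = LangBridgeConfig(lm="google/gemma-2-2b", dim_lm=2304)
model = LBGemma2(lb_config, random_init=False)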

Could you tell us why transformers version 4.37.2 is specifically required for LangBridge?

Thank you

MattYoon commented 2 days ago

To be frank, I haven't taken a deep look into this issue.

I think you can get some insight into what happened from the Transformers release notes. But at the moment, I don't have the capacity to debug LangBridge on newer versions of Transformers :( If you manage to find the issue, please let me know though.

Kosei1227 commented 2 days ago

Thank you so much for your quick response. I will go through the Transformers changelog and try to fix the issue.

Kosei1227 commented 2 days ago

Just to confirm, was transformers==4.37.2 the latest version when you designed LangBridge?

I would like to know why the authors used that version for the paper.

MattYoon commented 2 days ago

Yeah, I believe that was the latest at the time.