kaistAI / LangBridge

[ACL 2024] LangBridge: Multilingual Reasoning Without Multilingual Supervision
https://aclanthology.org/2024.acl-long.405/

Different Transformers Version for Gemma2 #18

Closed: Kosei1227 closed this issue 3 weeks ago

Kosei1227 commented 1 month ago

Hi, our team wants to test LangBridge on Gemma2. Since Gemma2 is not supported in transformers==4.37.2, I built the LangBridge model on Gemma2 with transformers==4.42.0. However, the results are terrible: the model cannot even produce coherent text. Here is the prompt template we used:

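```python
def get_template(template_name):
    if 'gemma2' in template_name:
        print("Gemma2 template")
        # Gemma's chat format expects a newline after each role tag and after
        # <end_of_turn>; without "\n" the turns run together into one line.
        return (
            "<bos><start_of_turn>user\n"
            "{user_message}<end_of_turn>\n"
            "<start_of_turn>model\n"
        )
    # Fail loudly instead of silently returning None for unknown templates
    raise ValueError(f"Unknown template: {template_name}")
```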

We also defined a custom LangBridge model as follows:

```python
class LBGemma2(LBBaseModel):
    # LBBaseModel and LangBridgeConfig come from the LangBridge codebase.
    config: LangBridgeConfig

    def __init__(self, config: LangBridgeConfig, random_init=True):
        from transformers import AutoModelForCausalLM, Gemma2Config, Gemma2ForCausalLM, Gemma2Model
        super().__init__(config, random_init=random_init)

        # Note: transformers recommends 'eager' attention when training Gemma2, since
        # other backends did not originally support its attention logit soft-capping.
        attn_implementation = 'flash_attention_2'

        if random_init:
            # Load the Gemma2 config and create a randomly initialized model.
            # Assigning model_config.attn_implementation directly has no effect and
            # never raises ImportError, so pass the backend to from_config instead.
            model_config = Gemma2Config.from_pretrained(config.lm)
            try:
                base_lm: Gemma2ForCausalLM = AutoModelForCausalLM.from_config(
                    model_config, attn_implementation=attn_implementation
                )
            except ImportError:  # flash_attn is not installed
                print('Not using Flash Attention!')
                base_lm = AutoModelForCausalLM.from_config(model_config)
        else:
            # Load the pretrained model from the Hugging Face Hub.
            # `attn_implementation` replaces the deprecated `use_flash_attention_2` flag.
            print('Loading Gemma2 model from pretrained weights')
            try:
                base_lm = Gemma2ForCausalLM.from_pretrained(
                    config.lm, attn_implementation=attn_implementation
                )
            except ImportError:  # flash_attn is not installed
                print('Not using Flash Attention!')
                base_lm = Gemma2ForCausalLM.from_pretrained(config.lm)

        # Ensure that the hidden size in LangBridgeConfig matches the LM's hidden size
        assert self.config.dim_lm == base_lm.config.hidden_size, \
            f"specified {self.config.dim_lm=} in LangBridgeConfig, but {config.lm} has hidden size={base_lm.config.hidden_size}"

        # Set up the language model body, LM head, and input embeddings
        self.lm: Gemma2Model = base_lm.model
        self.lm_head = base_lm.lm_head
        self.embeddings = base_lm.get_input_embeddings()
```

Could you tell us why transformers==4.37.2 is required for LangBridge?

Thank you

MattYoon commented 1 month ago

To be frank, I haven't taken a deep look into this issue.

You may get some insight into what happened by checking the Transformers release notes. Unfortunately, I don't currently have the capacity to debug LangBridge on newer versions of Transformers :( If you manage to find the issue, please let me know.

Kosei1227 commented 1 month ago

Thank you so much for the quick response. I will go through the changes in Transformers to fix the issue.

Kosei1227 commented 1 month ago

Just to confirm, was transformers==4.37.2 the latest version when you developed LangBridge?

I would like to know why the authors used this version in the paper.

MattYoon commented 1 month ago

Yeah, I believe that was the latest at the time.

Kosei1227 commented 4 weeks ago

In short, Gemma2Model did not return past key-value caches, unlike other models. This was resolved as of transformers==4.5.0, but Gemma2 has an atypical attention structure (interleaved sliding-window and global attention layers), and its hybrid caching system has a flaw: it does not account for the encoder features that are fed to the LM beforehand. So far, I have successfully modified the hybrid caching and the LangBridge modeling code, and inference generation now runs without errors.
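To illustrate the cache-sizing problem, here is a minimal, framework-free sketch. It assumes that the encoder features are fed to the LM as input embeddings ahead of the decoder tokens, so they occupy cache positions that a cache sized from `input_ids` alone would miss. `hybrid_cache_budget`, `naive_cache_len`, and `CacheBudget` are hypothetical names, not part of the transformers `HybridCache` API, and the 4096-token sliding window is only an illustrative value.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CacheBudget:
    total_len: int    # positions a global-attention layer must keep
    sliding_len: int  # positions a sliding-window layer must keep


def hybrid_cache_budget(num_encoder_tokens: int,
                        num_decoder_prefix_tokens: int,
                        max_new_tokens: int,
                        sliding_window: int) -> CacheBudget:
    """Size a hybrid (sliding + global) KV cache for a LangBridge-style input.

    Assumption: encoder features enter the LM as input embeddings, so they
    take up cache positions even though they never appear in input_ids.
    """
    total = num_encoder_tokens + num_decoder_prefix_tokens + max_new_tokens
    # Global layers attend to every previous position;
    # sliding-window layers only ever need the last `sliding_window` positions.
    return CacheBudget(total_len=total, sliding_len=min(total, sliding_window))


def naive_cache_len(num_decoder_prefix_tokens: int, max_new_tokens: int) -> int:
    """Hypothetical buggy sizing that only counts the LM's own input_ids."""
    return num_decoder_prefix_tokens + max_new_tokens


if __name__ == "__main__":
    budget = hybrid_cache_budget(num_encoder_tokens=120, num_decoder_prefix_tokens=1,
                                 max_new_tokens=256, sliding_window=4096)
    print(budget)                   # CacheBudget(total_len=377, sliding_len=377)
    print(naive_cache_len(1, 256))  # 257
```

With 120 encoder positions, one BOS token, and 256 new tokens, the cache needs 377 positions, whereas sizing it from the decoder inputs alone gives 257, leaving it 120 positions short.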

However, I now face an issue where the model outputs almost the same answer every time, which might be attributed to the following code:

```python
# Assume that there will always only be one token in the decoder inputs, assumption holds for existing HF models  <- Really? I need to check
input_ids = torch.LongTensor([self.lm_tokenizer.bos_token_id])
input_ids = input_ids.repeat(enc_ids.shape[0], 1).to(self.device)
attention_mask = torch.ones_like(input_ids)
```

Could you explain why you made this choice for inference? To my knowledge, it is unusual to feed only one token to the model.

Edit: As I understand it, LangBridge adds soft prompts to the inputs, but it seems we only pass the soft prompts from the encoder. I might be misunderstanding the methodology.

Edit 2: Never mind, I inspected the input IDs, and generation correctly takes the next input IDs. There must be another reason why I see mostly identical responses. Did you observe nearly identical responses in your research with LangBridge?

MattYoon commented 3 weeks ago

> Did you observe the almost identical responses in your research using LangBridge?

No, I haven't seen that behavior.

The single token I pass to the LM is meant to be the BOS token. I understand that it looks confusing, but I suspected that the LM would behave weirdly without an explicit BOS token, since it was always trained with one at the beginning.
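The layout described here (the encoder's soft prompt, followed by one explicit BOS token for the LM) can be illustrated with a small, framework-free sketch. It assumes the encoder outputs have already been projected to the LM's hidden size, and it uses plain lists in place of the torch tensors the real code uses; `build_lm_inputs` is a hypothetical helper, not LangBridge's actual implementation.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def build_lm_inputs(
    encoder_features: Sequence[Sequence[Vector]],  # [batch][enc_len][dim_lm], already projected
    encoder_mask: Sequence[Sequence[int]],         # [batch][enc_len], 1 = real token, 0 = padding
    bos_embedding: Vector,                         # the LM's embedding of its BOS token
) -> Tuple[List[List[Vector]], List[List[int]]]:
    """Build the LM's first-step inputs as [encoder soft prompt ; BOS] per example."""
    inputs_embeds: List[List[Vector]] = []
    attention_mask: List[List[int]] = []
    for feats, mask in zip(encoder_features, encoder_mask):
        if len(feats) != len(mask):
            raise ValueError("encoder features and mask must have the same length")
        # Append the BOS embedding right after the soft prompt, so the LM still
        # sees the token it was always trained to start from.
        inputs_embeds.append([list(v) for v in feats] + [list(bos_embedding)])
        # The BOS position is always attended to. With right-padded encoder
        # inputs, padding sits between the soft prompt and BOS, hidden by the mask.
        attention_mask.append(list(mask) + [1])
    return inputs_embeds, attention_mask


if __name__ == "__main__":
    embeds, mask = build_lm_inputs(
        encoder_features=[[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.0, 0.0]]],
        encoder_mask=[[1, 1], [1, 0]],
        bos_embedding=[9.0, 9.0],
    )
    print(mask)  # [[1, 1, 1], [1, 0, 1]]
```

Each subsequent generated token would then be embedded and appended after the BOS position, which is why only one LM token is needed at the first step.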

Kosei1227 commented 3 weeks ago

I see. So, as I currently understand it, LangBridge follows an encoder-decoder style and never tokenizes the input text with the LM tokenizer.

Although the tokens generated at each step are tokenized with the LM tokenizer, I suspect that losing the LM-tokenized input can hurt performance.

Did you try a prompt-tuning approach that creates soft prompts from the encoder?

MattYoon commented 3 weeks ago

Yeah, that's a good point. If I understood you correctly, here is a follow-up work to LangBridge that addresses what you described:

https://arxiv.org/abs/2405.17386

Kosei1227 commented 3 weeks ago

Oh, I see. Thank you so much for sharing the paper. Yes, that's exactly what I want. I will try this model soon.