kaistAI / LangBridge

[ACL 2024] LangBridge: Multilingual Reasoning Without Multilingual Supervision
https://aclanthology.org/2024.acl-long.405/

How to produce ../training/embeddings/llava_flores_lm_XXX.npy #12

Closed Kosei1227 closed 2 months ago

Kosei1227 commented 2 months ago

Hi, I'm trying to produce the embeddings of the LangBridge models to reproduce the paper. The Figure3 directory only contains gather_emb.py, which is used to produce the embedding space for the baseline.

Here is my embedding-generation code for LangBridge, but I cannot resolve the error it raises.

import os  # noqa

# os.environ['TRANSFORMERS_CACHE'] = '/mnt/sda/dongkeun/huggingface'  # noqa
# os.environ['HF_DATASETS_CACHE'] = '/mnt/sda/dongkeun/huggingface'  # noqa

from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
from langbridge import LangBridgeModel
from lm_eval.models.langbridge import LBSeq2SeqLM

if __name__ == '__main__':
    # print(lm_eval)
    # print(LBSeq2SeqLM)
    DEVICE = 'cuda:0'
    # LANGS = ['eng', 'fra', 'amh', 'ewe', 'hau', 'ibo', 'kin', 'lin', 'lug', 'orm', 'sna', 'sot', 'swa', 'twi', 'wol', 'xho', 'yor', 'zul']
    LANGS = [
        "eng_Latn",  # English
        "fra_Latn",  # French
        "amh_Ethi",  # Amharic
        "ewe_Latn",  # Ewe
        "hau_Latn",  # Hausa
        "ibo_Latn",  # Igbo
        "kin_Latn",  # Kinyarwanda
        "lin_Latn",  # Lingala
        "lug_Latn",  # Luganda
        "orm_Latn",  # Oromo
        "sna_Latn",  # Shona
        "sot_Latn",  # Sotho
        "swh_Latn",  # Swahili
        "twi_Latn",  # Twi
        "wol_Latn",  # Wolof
        "xho_Latn",  # Xhosa
        "yor_Latn",  # Yoruba
        "zul_Latn"   # Zulu
    ]
    args_enc_tokenizer = 'kaist-ai/langbridge_encoder_tokenizer'
    args_checkpoint_path = 'kaist-ai/metamath-langbridge-9b'

    try:
        enc_tokenizer = AutoTokenizer.from_pretrained(
            args_enc_tokenizer, use_fast=False)
    except Exception:
        enc_tokenizer = AutoTokenizer.from_pretrained(
            args_enc_tokenizer, use_fast=True)

    try:
        lm_tokenizer = AutoTokenizer.from_pretrained(
            args_checkpoint_path, use_fast=False)
    except Exception:
        lm_tokenizer = AutoTokenizer.from_pretrained(
            args_checkpoint_path, use_fast=True)

    if not enc_tokenizer.pad_token:
        enc_tokenizer.pad_token = enc_tokenizer.eos_token
    if not lm_tokenizer.pad_token:
        lm_tokenizer.pad_token = lm_tokenizer.eos_token

    model = LangBridgeModel.from_pretrained(args_checkpoint_path)
    model = LBSeq2SeqLM(
        model=model,
        enc_tokenizer=enc_tokenizer,
        lm_tokenizer=lm_tokenizer,
        batch_size=1
    )

    model.eval()
    model.to(DEVICE)

    lm_tokenizer = AutoTokenizer.from_pretrained(
        'kaist-ai/metamath-langbridge-9b', use_fast=False)

    metamath_template = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:\n"
    )
    system_message = ""

    for LANG in LANGS:
        ds = load_dataset('Muennighoff/flores200', LANG)['dev']

        all_embs = []
        for example in tqdm(ds):
            sentence = example['sentence']

            text = metamath_template.format(instruction=sentence)

            tokens = lm_tokenizer(text, return_tensors='pt').to(DEVICE)
            enc_input_ids = tokens['input_ids']

            with torch.no_grad():
                # emb = model.get_input_embeddings()(enc_input_ids).squeeze()
                emb = model(
                    enc_input_ids, output_hidden_states=True).hidden_states[-1].squeeze()
            mean = torch.mean(emb, dim=0)
            all_embs.append(mean)

        all_embs_tensor = torch.stack(all_embs, dim=0)

        # cast to float32 then to numpy
        all_embs_tensor = all_embs_tensor.float().cpu().numpy()

        print(all_embs_tensor.shape)

        np.save(
            f'embeddings/baseline_flores_metamath_lb_{LANG}.npy', all_embs_tensor)

I got the following error:

$ bash gather_emb_lb_afri.sh
/home/leelab-africanllm-2/miniconda3/envs/leia/lib/python3.11/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: resume_download is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use force_download=True.
  warnings.warn(
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00, 1.05it/s]
Failed to place model onto specified device. This may be because the model is quantized via bitsandbytes. If the desired GPU is being used, this message is safe to ignore.
Traceback (most recent call last):
  File "/home/leelab-africanllm-2/LangBridge/Figure3/gather_emb_lb_afri.py", line 71, in <module>
    model.eval()
    ^^^^^^^^^^
AttributeError: 'LBSeq2SeqLM' object has no attribute 'eval'
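My guess is that LBSeq2SeqLM is the lm-eval-harness wrapper rather than a torch.nn.Module, so eval() and to() may have to be called on the underlying LangBridgeModel before it is wrapped. A rough, untested sketch of what I mean (the wrapper construction is unchanged from my script above):

from transformers import AutoTokenizer
from langbridge import LangBridgeModel
from lm_eval.models.langbridge import LBSeq2SeqLM

DEVICE = 'cuda:0'
enc_tokenizer = AutoTokenizer.from_pretrained('kaist-ai/langbridge_encoder_tokenizer')
lm_tokenizer = AutoTokenizer.from_pretrained('kaist-ai/metamath-langbridge-9b')

# Untested assumption: LangBridgeModel itself is a regular torch module,
# so eval()/to() are called on it directly instead of on the wrapper.
model = LangBridgeModel.from_pretrained('kaist-ai/metamath-langbridge-9b')
model.eval()
model.to(DEVICE)

# The wrapper then only receives an already-placed model.
lb_wrapper = LBSeq2SeqLM(
    model=model,
    enc_tokenizer=enc_tokenizer,
    lm_tokenizer=lm_tokenizer,
    batch_size=1
)

But I'm not sure this is the intended usage, and it still doesn't tell me which forward call to use to get the hidden states out of the LangBridge model.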

If possible, could you share the code used to produce the embedding representations of the LangBridge models?

Thank you!

MattYoon commented 2 months ago

Hey @Kosei1227, thanks for reporting.

I just added another file in Figure3 that gathers embeddings for the LangBridge models!