u-masao / embed-text-recommender

MIT License
1 stars 0 forks source link

Model 改善: Llama2 系の Embedding モデルに対応したい #5

Open u-masao opened 9 months ago

u-masao commented 9 months ago

Llama2 系の Embedding モデルに対応することで、よりリッチな埋め込み空間が作れるかもしれない。 日本語の継続事前学習モデルを利用して、Last Hidden Layer 等から埋め込みを取得する方法が考えられる。 他のモデルと比較検討できるようにすること。

u-masao commented 9 months ago

調べたところいくつか先人の情報がある。

モデルの解説

u-masao commented 9 months ago

Llama2 といえば Swallow 7b も有力候補かも。

u-masao commented 9 months ago

Swallow 7b は、モデルバイナリが 20 GB もある。運用するのは辛そう。

u-masao commented 9 months ago

Colab V100 でもテストコードが動かないのでしばらく放置。

このあたりの情報が有用な気がするが、再現できないのでなんとも言えない。

Weighted-Mean-Pooling

Llama is a decoder with left-to-right attention. The idea behind weighted-mean_pooling is that the tokens at the end of the sentence should contribute more than the tokens at the beginning of the sentence because their weights are contextualized with the previous tokens, while the tokens at the beginning have far less context representation.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-chat-hf"

t = AutoTokenizer.from_pretrained(model_id)
t.pad_token = t.eos_token
m = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto" )
m.eval()

texts = [
    "this is a test",
    "this is another test case with a different length",
]
t_input = t(texts, padding=True, return_tensors="pt")

with torch.no_grad():
    last_hidden_state = m(**t_input, output_hidden_states=True).hidden_states[-1]

weights_for_non_padding = t_input.attention_mask * torch.arange(start=1, end=last_hidden_state.shape[1] + 1).unsqueeze(0)

sum_embeddings = torch.sum(last_hidden_state * weights_for_non_padding.unsqueeze(-1), dim=1)
num_of_none_padding_tokens = torch.sum(weights_for_non_padding, dim=-1).unsqueeze(-1)
sentence_embeddings = sum_embeddings / num_of_none_padding_tokens

print(t_input.input_ids)
print(weights_for_non_padding)
print(num_of_none_padding_tokens)
print(sentence_embeddings.shape)