jxmorris12 / vec2text

utilities for decoding deep representations (like sentence embeddings) back to text

Add gtr-base api support #25

ArvinZhuang closed this 7 months ago

ArvinZhuang commented 8 months ago

Hi @jxmorris12, I have changed api.py a bit and added an example for gtr-base using your published gtr-nq-32 models. I have tested it and it works well :)

ArvinZhuang commented 8 months ago

BTW, I noticed that the embedding given by the SentenceTransformer lib (see example below) is different from the embedding produced by the get_gtr_embeddings function I provided (you need get_gtr_embeddings for your correction models to work).

from sentence_transformers import SentenceTransformer

sentences = ["Jack Morris is a PhD student at Cornell Tech in New York City"]

# Encode with the full SentenceTransformer pipeline.
model = SentenceTransformer('sentence-transformers/gtr-t5-base')
embeddings = model.encode(sentences, convert_to_tensor=True)
print(embeddings[0][:10])

tensor([-0.0241,  0.0235, -0.0385,  0.0548, -0.0553, -0.0329,  0.0128,  0.0332,
         0.0071, -0.0482])

The above embedding is different from:

import vec2text
import torch
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer, PreTrainedModel

def get_gtr_embeddings(text_list,
                       encoder: PreTrainedModel,
                       tokenizer: PreTrainedTokenizer) -> torch.Tensor:
    # Tokenize to a fixed length of 128 tokens, matching what the corrector expects.
    inputs = tokenizer(text_list,
                       return_tensors="pt",
                       max_length=128,
                       truncation=True,
                       padding="max_length").to("cuda")

    with torch.no_grad():
        # Run the T5 encoder and mean-pool the hidden states over non-padding tokens.
        model_output = encoder(input_ids=inputs['input_ids'],
                               attention_mask=inputs['attention_mask'])
        hidden_state = model_output.last_hidden_state
        embeddings = vec2text.models.model_utils.mean_pool(hidden_state, inputs['attention_mask'])

    return embeddings

# Use only the T5 encoder stack; the corrector was trained on its mean-pooled outputs.
encoder = AutoModel.from_pretrained("sentence-transformers/gtr-t5-base").encoder.to("cuda")
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
corrector = vec2text.load_corrector("gtr-base")

embeddings = get_gtr_embeddings([
       "Jack Morris is a PhD student at Cornell Tech in New York City",
], encoder, tokenizer)

print(embeddings[0][:10])

tensor([-0.0148,  0.0272, -0.0089, -0.0040, -0.0253, -0.0294,  0.0267, -0.0552,
        -0.0283, -0.0457])
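
For completeness, the corrector loaded above can invert these embeddings back to text. A minimal sketch following the repo README (num_steps is the number of correction steps and is optional):

# Invert the embeddings back to text with the corrector loaded above.
texts = vec2text.invert_embeddings(
    embeddings=embeddings,
    corrector=corrector,
    num_steps=20,
)
print(texts)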

I figured out this is because SentenceTransformer('sentence-transformers/gtr-t5-base') applies a Dense layer and a Normalize layer after the mean pooling (see https://huggingface.co/sentence-transformers/gtr-t5-base/blob/main/modules.json), and these two layers are ignored by your training code. Should I worry about this?
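
In case it helps, here is a minimal sketch that reproduces the SentenceTransformer output on top of get_gtr_embeddings, assuming the module order listed in modules.json (0: Transformer, 1: Pooling, 2: Dense, 3: Normalize):

import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

st_model = SentenceTransformer('sentence-transformers/gtr-t5-base')
dense = st_model[2].to("cuda")  # the Dense projection module (index 2 in modules.json)

embeddings = get_gtr_embeddings([
    "Jack Morris is a PhD student at Cornell Tech in New York City",
], encoder, tokenizer)

# Apply the Dense projection, then L2-normalize, mirroring the last two
# modules of the SentenceTransformer pipeline; the result should match
# model.encode(...) from the first snippet.
projected = dense({'sentence_embedding': embeddings})['sentence_embedding']
print(F.normalize(projected, p=2, dim=1)[0][:10])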

jxmorris12 commented 7 months ago

This is a great addition – thank you!

And great call. This is a little sad, because it means our model can't actually work properly on precomputed GTR embeddings that come from a different library. It also explains a tiny performance dropoff I noticed when using GTR embeddings for retrieval in the defense section of our paper. Hopefully I or someone else will get around to training a new model (probably by fine-tuning the existing ones) on the "correct" GTR embeddings. Thanks for pointing out this issue, and sorry for any trouble it may have caused.