jxmorris12 / vec2text

utilities for decoding deep representations (like sentence embeddings) back to text

Were the gtr models trained without normalization? #43

Closed · Jackmin801 closed this 5 months ago

Jackmin801 commented 5 months ago

For the gtr set of models (jxm/gtr__nq__32, jxm/gtr__nq__32__correct), were they trained on sentence-transformers/gtr-t5-base embeddings without normalization? In the README example, the function used to compute embeddings does not include a normalization step:

import torch
import vec2text
from transformers import PreTrainedModel, PreTrainedTokenizer


def get_gtr_embeddings(text_list,
                       encoder: PreTrainedModel,
                       tokenizer: PreTrainedTokenizer) -> torch.Tensor:

    inputs = tokenizer(text_list,
                       return_tensors="pt",
                       max_length=128,
                       truncation=True,
                       padding="max_length").to("cuda")

    with torch.no_grad():
        model_output = encoder(input_ids=inputs['input_ids'],
                               attention_mask=inputs['attention_mask'])
        hidden_state = model_output.last_hidden_state
        # Mean-pool the token states over non-padding positions;
        # note that no L2 normalization is applied afterwards.
        embeddings = vec2text.models.model_utils.mean_pool(hidden_state, inputs['attention_mask'])

    return embeddings

This is surprising to me because I was under the impression from the paper that the optimization minimizes cosine distance, which only depends on direction and is therefore invariant to normalization. However, when I normalize the embeddings before passing them to the corrector, the results degrade.
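
For reference, roughly how I am testing this, as a minimal sketch built on the README API (load_pretrained_corrector("gtr-base") and invert_embeddings, plus get_gtr_embeddings from above); the example sentence and step count are just placeholders:

import torch.nn.functional as F
import transformers
import vec2text

# Encoder/tokenizer/corrector setup as in the README.
encoder = transformers.AutoModel.from_pretrained("sentence-transformers/gtr-t5-base").encoder.to("cuda")
tokenizer = transformers.AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
corrector = vec2text.load_pretrained_corrector("gtr-base")

embeddings = get_gtr_embeddings(["jack morris is a phd student at cornell tech in new york city"],
                                encoder, tokenizer)

# Inversion with the embeddings exactly as produced above (works well for me).
text_raw = vec2text.invert_embeddings(embeddings=embeddings,
                                      corrector=corrector,
                                      num_steps=20)

# Inversion after L2-normalizing the embeddings first (noticeably worse in my runs).
text_norm = vec2text.invert_embeddings(embeddings=F.normalize(embeddings, p=2, dim=1),
                                       corrector=corrector,
                                       num_steps=20)

print(text_raw)
print(text_norm)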

jxmorris12 commented 5 months ago

Unfortunately, yes. We use the Hugging Face GTR implementation, which doesn't pool properly for some reason. Someone trained a model with the proper embeddings, which are (I think) normalized. I doubt it makes much of a difference in results, provided there isn't a train-test mismatch. See #28.
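
If you want to see the convention mismatch concretely, comparing vector norms is a quick check. A sketch, not something I've run; it assumes get_gtr_embeddings, encoder, and tokenizer from the snippet above, and relies on my recollection that the sentence-transformers pipeline applies its own pooling (and, I think, normalization) on top:

from sentence_transformers import SentenceTransformer

# Embeddings from the HF encoder + mean_pool path used for training (see the snippet above).
hf_emb = get_gtr_embeddings(["jack morris is a phd student at cornell tech in new york city"],
                            encoder, tokenizer).cpu()

# Embeddings from the sentence-transformers pipeline for the same base model.
st_model = SentenceTransformer("sentence-transformers/gtr-t5-base")
st_emb = st_model.encode(["jack morris is a phd student at cornell tech in new york city"],
                         convert_to_tensor=True).cpu()

print(hf_emb.norm(dim=1))  # generally not 1.0: unnormalized mean pooling
print(st_emb.norm(dim=1))  # ~1.0 if the sentence-transformers pipeline normalizes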

Jackmin801 commented 5 months ago

Ok thanks!