Muennighoff / sgpt

SGPT: GPT Sentence Embeddings for Semantic Search
https://arxiv.org/abs/2202.08904
MIT License
823 stars, 51 forks

Theory question about the token weightings for symmetric search. #9

Closed cm2435 closed 1 year ago

cm2435 commented 1 year ago

First things first, I loved reading your paper. It was clear and concise, and it has great implications for semantic search going forward. Cannot compliment it highly enough!

One question: I would like to use a similar method to get semantic embeddings from non-GPT autoregressive language models. In the paper I read:

Due to the causal attention mask in an auto-regressive decoder transformer, tokens do not attend to
future tokens like in an encoder transformer. Hence, only the last token has attended to all tokens in a
sequence. To account for this information mismatch, we propose to give later tokens a higher weight
using a position-weighted mean pooling method:

$$v = \sum_{i=1}^{S} w_i h_i \quad \text{where} \quad w_i = \frac{i}{\sum_{i=1}^{S} i} \qquad (2)$$

where $S$ is the sequence length, $h_i$ the $i$th hidden state and $v$ the query or document embedding. We
compare weighted mean pooling with last token pooling, where the hidden state of the final token is
the embedding, and regular mean pooling.

This trick is really neat, but I was wondering whether it would also work for autoregressive decoder-only models that use a causal language model loss, for example the XGLM model family? https://huggingface.co/facebook/xglm-564M

How about autoregressive LMs that do not use a causal language model loss, but instead use next-token prediction language modeling, such as the CodeGen model family? https://arxiv.org/pdf/2203.13474.pdf [if you are unfamiliar, the training section is 2.2 :) ]

I understand if there is not a clear answer to these questions, but I would love to hear your thoughts either way. Thanks again!

Muennighoff commented 1 year ago

If I'm not mistaken, next-token prediction language modeling == causal language model loss == the objective of the pretrained SGPT models. They are all causal decoder-only models with the same loss objective, so yes, weighted mean pooling should work well for all of them.
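
For anyone who wants to try this on another causal decoder, here is a minimal sketch of position-weighted mean pooling as described in the quoted equation (2). It is not the code from the SGPT repo. The function name `weighted_mean_pool` and the NumPy array inputs are assumptions for illustration: `hidden_states` stands for the last-layer hidden states of whatever model is used, and `attention_mask` marks real tokens with 1 and padding with 0. Positions are counted over real tokens only, and every sequence is assumed to contain at least one real token.

```python
import numpy as np


def weighted_mean_pool(hidden_states, attention_mask):
    """Position-weighted mean pooling over the token axis (hypothetical helper).

    hidden_states:  (batch, seq_len, dim) last-layer hidden states
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    Returns an array of shape (batch, dim).
    """
    hidden_states = np.asarray(hidden_states, dtype=np.float64)
    mask = np.asarray(attention_mask, dtype=np.float64)

    # Position i = 1..S of each real token; padding positions get weight 0.
    positions = np.cumsum(mask, axis=1) * mask

    # w_i = i / sum_i i, normalized per sequence (assumes >= 1 real token).
    weights = positions / positions.sum(axis=1, keepdims=True)

    # v = sum_i w_i * h_i
    return (hidden_states * weights[:, :, None]).sum(axis=1)


# Toy input: one sequence with three real tokens and one padding token.
hidden = np.array([[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [9.0, 9.0]]])
mask = np.array([[1, 1, 1, 0]])
embedding = weighted_mean_pool(hidden, mask)  # token weights are 1/6, 2/6, 3/6, 0
print(embedding)
```

Nothing in the pooling step depends on which causal decoder produced the hidden states, which is why the same function should apply to any of the models mentioned above.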

cm2435 commented 1 year ago

@Muennighoff So it is. Sorry, thanks for the clarification.