phartman-keysight opened this issue 4 months ago
You should be able to just load it as follows:
model = GritLM("GritLM/emb_m7_nodes16_fast", torch_dtype="auto", mode='embedding')
and use it in the same way as GritLM/GritLM-7B. Otherwise, fine-tuning GritLM/GritLM-7B is probably just as good; I don't know which one would actually perform better.
I know standard GRIT generally uses mean pooling of the last hidden state for embeddings, and that it can use weighted mean, CLS, or last-token pooling instead. Mean pooling of the token embeddings is a common way to generate sentence embeddings, but I've also seen fully connected "pooler" layers: a single final dense layer that produces the embedding. How did you decide on mean pooling rather than having a "language head" and an "embedding head" that you apply to the last hidden state for either output (as opposed to a language head plus mean pooling)? Do you think there would be performance improvements if someone were to apply that approach?
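The pooling strategies named above (mean, weighted mean, last token) can be sketched on a single sequence of token embeddings. This is an illustrative sketch, not GritLM's internals; function names and shapes are assumptions, and NumPy stands in for the actual framework.

```python
# Illustrative pooling strategies over token embeddings of shape
# (seq_len, hidden_dim), with a 0/1 attention mask marking real tokens.
import numpy as np

def mean_pool(hidden, mask):
    """Average the hidden states of the non-padding tokens."""
    m = mask[:, None].astype(hidden.dtype)        # (seq_len, 1)
    return (hidden * m).sum(axis=0) / m.sum()

def weighted_mean_pool(hidden, mask):
    """Weight later tokens more heavily (position-proportional weights)."""
    weights = np.arange(1, hidden.shape[0] + 1, dtype=hidden.dtype) * mask
    return (hidden * weights[:, None]).sum(axis=0) / weights.sum()

def last_token_pool(hidden, mask):
    """Take the hidden state of the final non-padding token."""
    last = int(mask.nonzero()[0][-1])
    return hidden[last]
```

Each variant maps `(seq_len, hidden_dim)` to a single `(hidden_dim,)` vector, which is the shape contract being discussed here.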
So you would apply that head over the seq len? I.e. (batch size, seq len, hidden dim) -> (batch size, 1, hidden dim)?
The problem is that the seq len may change depending on the sample. You'd likely have to pad every input to the same number of tokens, and those padding tokens would then become part of the embedding, since they're part of the matrix multiply, which may hurt performance.
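The padding issue described above can be shown with a toy example: a dense "embedding head" over the flattened padded sequence picks up whatever hidden states the padding positions carry, while masked mean pooling ignores them. All names and dimensions here are hypothetical.

```python
# Toy demonstration: pad-token hidden states leak into a dense head's
# output but not into a mask-aware mean. Dimensions are illustrative.
import numpy as np

rng = np.random.default_rng(0)
hidden_dim, max_len = 3, 4
W = rng.normal(size=(max_len * hidden_dim, hidden_dim))  # hypothetical dense head

def dense_head(padded_hidden):
    """Flatten the padded (max_len, hidden_dim) input and project it."""
    return padded_hidden.reshape(-1) @ W

def masked_mean(hidden, mask):
    """Mean over real tokens only; padding never enters the sum."""
    m = mask[:, None].astype(hidden.dtype)
    return (hidden * m).sum(axis=0) / m.sum()

real = rng.normal(size=(2, hidden_dim))      # hidden states of 2 real tokens
pad_a = rng.normal(size=(2, hidden_dim))     # pad positions still get states
pad_b = rng.normal(size=(2, hidden_dim))     # ...and those states can vary
a = np.vstack([real, pad_a])
b = np.vstack([real, pad_b])
mask = np.array([1, 1, 0, 0])
```

With the same real tokens but different padding states, `dense_head(a)` and `dense_head(b)` differ, while `masked_mean(a, mask)` and `masked_mean(b, mask)` are identical, which is the failure mode being pointed out.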
Yeah, I see what you're saying. The reason the language model head works is that it only uses the last token's embedding to generate the logits for the next token. I didn't realize that, so if I want something comparable I wouldn't build an "embedding head"; I would just use the last-token approach, which you already support.
I understand now, thanks for the response.
Continuing from our conversation in https://github.com/ContextualAI/gritlm/issues/13. I just think it needed a new ticket at this point.
I am trying to finetune embeddings only, so I took your (@Muennighoff's) recommendation of using GritLM/emb_m7_nodes16_fast, but I don't see an embedding for the entire sentence/article, only the token embeddings. Am I misunderstanding something?
The standard GRIT model is both generative and an encoder, so the forward function is generative and encode is the embedding function. I use model.encode(input_tokens, instruction), which returns a vector with shape (4096,), and that works great. The model you recommended has no generative part, so I assumed forward is the embedding function and there is no encode function, right? The issue I'm hitting is that when I run model(input_tokens) I get back a tuple with a 4096-dim embedding for each token, as opposed to a single embedding for the entire article. Should I be doing pooling on these, or is there some other function I should use to get the embedding?
Here is some example code:
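A minimal sketch of the pooling step described above, assuming per-token outputs of shape (batch, seq_len, hidden_dim) and a 0/1 attention mask; the function name and shapes are illustrative, not the GritLM API, and NumPy stands in for the actual tensors.

```python
# Mask-aware mean pooling from per-token embeddings to one embedding
# per sequence: (batch, seq_len, hidden_dim) -> (batch, hidden_dim).
import numpy as np

def pool_batch(token_embeddings, attention_mask):
    """Average each sequence's real-token embeddings, skipping padding."""
    m = attention_mask[..., None].astype(token_embeddings.dtype)
    return (token_embeddings * m).sum(axis=1) / m.sum(axis=1)

# e.g., with hidden_dim = 4096 this yields one (4096,) vector per input,
# matching what model.encode returns for the standard GritLM model.
```

This is also roughly what a library `encode` helper does internally when it reduces per-token states to a single sentence embedding.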
Also, the embeddings won't be the same since they are different models, but they should result in similar similarity scores, right?
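On the point above: embedding vectors from two different models live in different spaces and aren't directly comparable, so what you compare is the similarity score each model assigns to the same pair of texts. A minimal sketch of that score (the usual cosine similarity; names are illustrative):

```python
# Cosine similarity between two embedding vectors: compare these scores
# across models, not the raw vectors themselves.
import numpy as np

def cosine_similarity(a, b):
    """Cosine of the angle between vectors a and b, in [-1, 1]."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
```

Whether two models produce *similar* scores for the same pairs is an empirical question, but this is the quantity you'd check.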