dleemiller / WordLlama

Things you can do with the token embeddings of an LLM
MIT License

Need detailed example on how to extract the embedding model from LLM #14

Closed: harshitv804 closed this 1 month ago

harshitv804 commented 1 month ago

Hi, I saw there is an option to extract the embedding model from an LLM. Can this be applied to any LLM on Hugging Face? If so, I need a detailed example of how to extract it and use it with WordLlama to generate vector embeddings.

Thank you.

dleemiller commented 1 month ago

Sure -- I can provide an example. The general pattern is:

So there is training involved. You can skip that part, but I'm not sure how the embeddings will perform without being adapted to a task. Doing the projection also helps reduce the embedding size.
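To make the size argument concrete, here is a rough Python sketch. The 32k vocabulary, the 4096-dimensional input and the 256-dimensional output are illustrative assumptions, not WordLlama's actual configuration, and the projection weights are random placeholders standing in for weights that would be learned during training.

```python
import random


def table_bytes(vocab_size: int, dim: int, bytes_per_value: int = 2) -> int:
    """Bytes needed to store a vocab_size x dim table (2 bytes per value = fp16/bf16)."""
    return vocab_size * dim * bytes_per_value


def project(vector: list[float], weights: list[list[float]]) -> list[float]:
    """Linear projection: `weights` has shape (out_dim, in_dim)."""
    return [sum(w * x for w, x in zip(row, vector)) for row in weights]


# Hypothetical sizes: 32k vocabulary, 4096-dim token embeddings, projected to 256 dims.
VOCAB, IN_DIM, OUT_DIM = 32_000, 4096, 256
print(table_bytes(VOCAB, IN_DIM))   # 262144000 bytes (~262 MB)
print(table_bytes(VOCAB, OUT_DIM))  # 16384000 bytes (~16 MB)

# Placeholder weights; a real projection would be trained on a task.
rng = random.Random(0)
weights = [[rng.gauss(0.0, 0.02) for _ in range(IN_DIM)] for _ in range(OUT_DIM)]
token_vector = [rng.gauss(0.0, 1.0) for _ in range(IN_DIM)]
print(len(project(token_vector, weights)))  # 256
```

Because the token vectors are static, the projection only has to be applied once per vocabulary row, so the stored table shrinks by the ratio of the dimensions (16x with these numbers).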

There is no "one size fits all" approach to extracting the embedding, since it depends on how the creators of the model named their tensors. But if you let me know which model you're trying to extract from, I can write a tutorial around it. It's usually not very difficult to identify the correct tensor to extract, and sometimes there is a manifest.json you can read to identify the correct name. If the model is split into multiple safetensors files, it's most often in the first one.
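As a starting point for finding the right tensor, the sketch below reads the JSON header of a `.safetensors` file with only the standard library (the format is an 8-byte little-endian header length followed by a JSON header) and lists 2-D tensors whose names contain "embed". For sharded Hugging Face checkpoints, the `weight_map` in `model.safetensors.index.json` maps each tensor name to the shard that holds it. The keyword filter is a heuristic assumption: many Llama-style models name the table `model.embed_tokens.weight`, but other architectures differ (GPT-2, for example, uses `wte`).

```python
from __future__ import annotations

import json
import struct
from pathlib import Path


def read_safetensors_header(path: str | Path) -> dict:
    """Return {tensor_name: {"dtype", "shape", "data_offsets"}} for a .safetensors file."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # u64, little-endian
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional metadata entry, not a tensor
    return header


def find_embedding_candidates(header: dict, keyword: str = "embed") -> list[tuple[str, list[int]]]:
    """2-D tensors whose name contains `keyword`, largest first dimension (vocab size) first."""
    hits = [
        (name, info["shape"])
        for name, info in header.items()
        if keyword in name and len(info["shape"]) == 2
    ]
    return sorted(hits, key=lambda hit: hit[1][0], reverse=True)


def shard_for_tensor(index_path: str | Path, tensor_name: str) -> str:
    """Look up which shard file holds `tensor_name` in a sharded checkpoint's index."""
    with open(index_path, encoding="utf-8") as f:
        return json.load(f)["weight_map"][tensor_name]


# Example (hypothetical file names; point them at your downloaded model folder):
# header = read_safetensors_header("model-00001-of-00002.safetensors")
# print(find_embedding_candidates(header))
# print(shard_for_tensor("model.safetensors.index.json", "model.embed_tokens.weight"))
```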

harshitv804 commented 1 month ago

Hi, I would like a tutorial using a small LLM like Gemma 2 2B on HF. I'm a beginner, so let me quickly describe my use case; please correct me if I'm wrong anywhere. We already have separate SOTA embedding models based on BERT, so that part is fine. I thought the embeddings in decoder LLMs would be more accurate than BERT embeddings. There is already something like LLM2Vec, but it downloads the full model and then adds an adapter to it. So my question is: can we simply extract the LLM up to its embedding layer, save that separately, and use it to generate vectors, so that it would be fast, accurate, and smaller than BERT-based embedding models?

Thanks

dleemiller commented 1 month ago

If you are thinking of extracting only the token embeddings, then no, I don't think you will accomplish what you're trying to do. The transformer layers in the decoder are what learn contextual information from the text sequence, and that context is essential for forming good embedding representations.

WordLlama is meant more for simplicity and speed, and it makes concessions to achieve them. It's really a utility that lowers the barrier to working with strings and performs simple operations fast. You tokenize your text, look up each token's vector (one of the 32k vectors in the vocabulary), and average them together. The tokens know nothing of their neighbors in the sequence.
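A toy illustration of that averaging, assuming a made-up four-word vocabulary, a whitespace tokenizer and a 3-dimensional table in place of a real subword tokenizer and extracted embeddings. It shows the consequence described above: two sentences with the same tokens in a different order get the same vector.

```python
from __future__ import annotations

# Toy stand-ins (hypothetical): a real setup would use the LLM's subword
# tokenizer and its extracted, ideally projected, token-embedding table.
VOCAB = {"the": 0, "dog": 1, "bites": 2, "man": 3}
TABLE = [
    [0.1, 0.0, 0.3],  # the
    [0.9, 0.2, 0.1],  # dog
    [0.0, 0.8, 0.5],  # bites
    [0.7, 0.1, 0.6],  # man
]


def tokenize(text: str) -> list[int]:
    """Whitespace tokenizer over the toy vocabulary."""
    return [VOCAB[word] for word in text.lower().split()]


def embed(text: str) -> list[float]:
    """Mean of static token vectors: no attention, so word order is ignored."""
    ids = tokenize(text)
    if not ids:
        raise ValueError("empty input")
    dim = len(TABLE[0])
    return [sum(TABLE[i][d] for i in ids) / len(ids) for d in range(dim)]


a = embed("the dog bites the man")
b = embed("the man bites the dog")
print(all(abs(x - y) < 1e-12 for x, y in zip(a, b)))  # True: same bag of tokens
```

This is why a decoder's transformer layers matter for embeddings: without them, "dog bites man" and "man bites dog" are indistinguishable.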

I have added an example of how to extract the token embeddings: https://github.com/dleemiller/WordLlama/blob/main/tutorials/extract_token_embeddings.md
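Independent of the linked tutorial, here is a minimal sketch of the final step: reading one tensor's bytes out of a `.safetensors` file into a NumPy array. It assumes you already know the tensor name and shard, handles only F32, F16 and BF16, and converts BF16 by shifting it into the top half of a float32, because NumPy has no native bfloat16 type. In practice, the `safetensors` package's `safe_open(...).get_tensor(name)` does the same job more robustly.

```python
from __future__ import annotations

import json
import struct

import numpy as np


def load_tensor(path: str, name: str) -> np.ndarray:
    """Read one tensor from a .safetensors file and return it as float32."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
        info = header[name]
        begin, end = info["data_offsets"]
        # Offsets are relative to the start of the data section, right after the header.
        f.seek(8 + header_len + begin)
        raw = f.read(end - begin)

    dtype = info["dtype"]
    if dtype == "F32":
        arr = np.frombuffer(raw, dtype="<f4").astype(np.float32)
    elif dtype == "F16":
        arr = np.frombuffer(raw, dtype="<f2").astype(np.float32)
    elif dtype == "BF16":
        # bfloat16 is the upper 16 bits of a float32; NumPy has no native bf16 type.
        bits = np.frombuffer(raw, dtype="<u2").astype(np.uint32) << 16
        arr = bits.view(np.float32)
    else:
        raise ValueError(f"unsupported dtype: {dtype}")
    return arr.reshape(info["shape"])


# Example (hypothetical shard and tensor names; use the ones you found for your model):
# table = load_tensor("model-00001-of-00002.safetensors", "model.embed_tokens.weight")
# np.save("token_embeddings.npy", table)  # shape: (vocab_size, hidden_size)
```

Converting to float32 doubles the memory of a BF16 table, which matters for large-vocabulary models.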

Hopefully that will help.