run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Question]: Custom embedding max query length #7679

Closed: austinmw closed this issue 1 year ago

austinmw commented 1 year ago


Question

Hi, I'm using a LangChain embedding model with a maximum input length of 512, and this causes an error in VectorStoreIndex.from_documents(). What do I need to change so that the embedding pipeline knows my embedding model's maximum input length is 512?

from langchain.embeddings import BedrockEmbeddings
from langchain.llms.bedrock import Bedrock

from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex, set_global_service_context
from llama_index.llms import LangChainLLM

# `bedrock` is a boto3 Bedrock client created earlier (not shown here)
embed_model = BedrockEmbeddings(
    client=bedrock,
    model_id='amazon.titan-e1t-medium',
)
llm_model = Bedrock(
    client=bedrock,
    model_id='amazon.titan-tg1-large',
)

# Wrap the LangChain LLM so LlamaIndex can use it
llm = LangChainLLM(llm_model)

service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
data = SimpleDirectoryReader(input_dir="./data/").load_data()
index = VectorStoreIndex.from_documents(data, service_context=service_context)
dosubot[bot] commented 1 year ago

🤖

Hi there,

It looks like the text chunks created by VectorStoreIndex.from_documents() are longer than the maximum input length of your embedding model.

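To resolve this, set chunk_size explicitly in the ServiceContext to a value no larger than the maximum input length allowed by your embedding model. This was suggested in a similar issue, #990, and it appears to have resolved the problem for some users. Keep in mind that chunk sizes are counted with LlamaIndex's own tokenizer, which may not match your embedding model's tokenizer, so leaving some margin below 512 is safer. Here is how you can do it:

service_context = ServiceContext.from_defaults(
    llm=llm,
    embed_model=embed_model,
    chunk_size=512,  # older releases called this chunk_size_limit (now deprecated)
)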
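To see why capping the chunk size helps, here is a simplified, standalone illustration of token-window chunking. It is not LlamaIndex's actual splitter (which also respects sentence boundaries and uses its own tokenizer); the helper name `chunk_tokens` and the whitespace "tokenization" are assumptions for illustration only. Every window holds at most `chunk_size` tokens, so no single chunk handed to the embedding model can exceed that limit.

```python
from typing import List


def chunk_tokens(tokens: List[str], chunk_size: int, chunk_overlap: int = 0) -> List[List[str]]:
    """Split a token list into windows of at most chunk_size tokens.

    Hypothetical helper for illustration: consecutive windows share
    chunk_overlap tokens, mimicking a chunk_size/chunk_overlap setting.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # this window already reaches the end of the input
    return chunks


if __name__ == "__main__":
    # Whitespace-separated words stand in for a real tokenizer here.
    words = ("lorem " * 1200).split()
    chunks = chunk_tokens(words, chunk_size=512, chunk_overlap=20)
    print(len(chunks), max(len(c) for c in chunks))  # prints: 3 512
```

The overlap keeps some shared context between neighbouring chunks; it does not change the upper bound on chunk length.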

However, if the issue persists, you can truncate each text before it reaches the embedding model by overriding the _get_text_embedding and _aget_text_embedding methods in a wrapper class. LlamaIndex's AdapterEmbeddingModel is one example of this pattern: it is a BaseEmbedding subclass that holds another embedding model (self._base_embed_model) and overrides these two methods. It is intended for fine-tuned adapters, though, so for a LangChain model such as BedrockEmbeddings a small wrapper of your own that follows the same pattern is usually easier than editing library code.

Here is an example of how the _get_text_embedding method of such a wrapper could look:

from typing import List

def _get_text_embedding(self, text: str) -> List[float]:
    # Keep only the first 512 whitespace-separated words before embedding
    words = text.split()
    if len(words) > 512:
        text = " ".join(words[:512])
    return self._base_embed_model._get_text_embedding(text)

And similarly, modify the _aget_text_embedding method:

async def _aget_text_embedding(self, text: str) -> List[float]:
    # Same word-based truncation, then await the base model's async method
    words = text.split()
    if len(words) > 512:
        text = " ".join(words[:512])
    return await self._base_embed_model._aget_text_embedding(text)
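The two overrides above follow a simple delegation pattern: trim the input, then forward it to the wrapped model. The standalone sketch below reproduces that pattern without depending on LlamaIndex so it can be run on its own. `TruncatingEmbedding`, `FakeEmbedModel`, and `max_words` are hypothetical names, and the private method names only mirror the ones above; in real code you would presumably put this logic in a BaseEmbedding subclass that wraps your actual embedding model.

```python
import asyncio
from typing import List


class FakeEmbedModel:
    """Hypothetical stand-in for a real embedding model.

    Raises if given more than max_words words (mimicking an input limit)
    and returns a trivial 1-dimensional "embedding": the word count.
    """

    def __init__(self, max_words: int = 512) -> None:
        self.max_words = max_words

    def _get_text_embedding(self, text: str) -> List[float]:
        n = len(text.split())
        if n > self.max_words:
            raise ValueError(f"input too long: {n} words")
        return [float(n)]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)


class TruncatingEmbedding:
    """Hypothetical wrapper that truncates text before delegating."""

    def __init__(self, base_embed_model, max_words: int = 512) -> None:
        self._base_embed_model = base_embed_model
        self._max_words = max_words

    def _truncate(self, text: str) -> str:
        # Leave short texts untouched; cut long ones to the first max_words words.
        words = text.split()
        if len(words) > self._max_words:
            return " ".join(words[: self._max_words])
        return text

    def _get_text_embedding(self, text: str) -> List[float]:
        return self._base_embed_model._get_text_embedding(self._truncate(text))

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return await self._base_embed_model._aget_text_embedding(self._truncate(text))


if __name__ == "__main__":
    wrapped = TruncatingEmbedding(FakeEmbedModel(max_words=512), max_words=512)
    long_text = "word " * 2000
    print(wrapped._get_text_embedding(long_text))  # prints: [512.0]
    print(asyncio.run(wrapped._aget_text_embedding(long_text)))  # prints: [512.0]
```

Putting the truncation in one helper keeps the sync and async paths consistent, so a later change to the limit only has to be made once.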

These modifications ensure that the text passed to the base embedding model does not exceed 512 words. Note that this counts words, not tokens: a single word is often split into several tokens, so 512 words can still exceed a 512-token limit. You may need to use a lower word cap, or measure length with your embedding model's own tokenizer instead.
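If you can count tokens the way your embedding model does, one option is to search for the longest word prefix that fits the token budget. The sketch below does this with a binary search, assuming a `count_tokens` callable whose result never decreases as words are appended. `truncate_to_token_limit` and `approx_count_tokens` are hypothetical helpers; the approximation (one token per four characters of each word) is only a placeholder and is not the Titan tokenizer.

```python
import math
from typing import Callable


def approx_count_tokens(text: str) -> int:
    """Placeholder tokenizer (assumption): ceil(len(word) / 4) tokens per word."""
    return sum(math.ceil(len(w) / 4) for w in text.split())


def truncate_to_token_limit(
    text: str, max_tokens: int, count_tokens: Callable[[str], int]
) -> str:
    """Return the longest word prefix of text with at most max_tokens tokens.

    Assumes count_tokens is non-decreasing as words are appended.
    """
    if count_tokens(text) <= max_tokens:
        return text
    words = text.split()
    lo, hi = 0, len(words)  # search for the largest prefix length that fits
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if count_tokens(" ".join(words[:mid])) <= max_tokens:
            lo = mid  # a prefix of mid words fits; try longer
        else:
            hi = mid - 1  # too long; try shorter
    return " ".join(words[:lo])


if __name__ == "__main__":
    text = "internationalization " * 200  # 20-character words -> 5 tokens each
    out = truncate_to_token_limit(text, 512, approx_count_tokens)
    print(len(out.split()), approx_count_tokens(out))  # prints: 102 510
```

The binary search needs only about log2(number of words) tokenizer calls, which matters if counting tokens is expensive.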

I hope this helps! Let me know if you have any other questions.

This response is meant to be useful and save you time. It is not meant to be a precise solution, but rather a starting point for your own research.
