zilliztech / GPTCache

Semantic cache for LLMs. Fully integrated with LangChain and llama_index.
https://gptcache.readthedocs.io
MIT License

[Bug]: GPT cache llama index integration #554

Open MeghaWalia-eco opened 9 months ago

MeghaWalia-eco commented 9 months ago

Current Behavior

I am trying to integrate GPTCache with LlamaIndex, but LLMPredictor does not accept a cache argument. To work around this, I created a CachedLLMPredictor class that extends LLMPredictor:

from typing import Any

from llama_index import BasePromptTemplate
from llama_index.llm_predictor.base import LLMPredictor
from pydantic import BaseModel


class CachedLLMPredictor(LLMPredictor):
    cache: Any  # declare the cache attribute so pydantic accepts it

    class Config(BaseModel.Config):
        extra = 'allow'  # allow extra attributes

    def __init__(self, cache, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cache = cache

    def predict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        cache_key = (prompt, tuple(sorted(prompt_args.items())))
        if cache_key in self.cache:
            return self.cache[cache_key]
        else:
            result = super().predict(prompt, **prompt_args)
            self.cache[cache_key] = result
            return result

But the self.cache[cache_key] = result and return self.cache[cache_key] lines throw errors, so this does not work.
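For reference, the GPTCache object is not a dict-like mapping; its documented access path is the get and put functions in gptcache.adapter.api. Below is a rough sketch of how the override might look if the cache attribute held a native gptcache.Cache instead; the prompt.format cache key and the cache_obj keyword are assumptions taken from those docs, not something tested against this service.

from typing import Any

from gptcache import Cache
from gptcache.adapter.api import get, put
from llama_index import BasePromptTemplate
from llama_index.llm_predictor.base import LLMPredictor
from pydantic import BaseModel


class CachedLLMPredictor(LLMPredictor):
    cache: Any  # a gptcache.Cache instance

    class Config(BaseModel.Config):
        extra = 'allow'

    def __init__(self, cache: Cache, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.cache = cache

    def predict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        # GPTCache keys on text, so flatten the template and its arguments into one string
        cache_key = prompt.format(**prompt_args)
        cached = get(cache_key, cache_obj=self.cache)
        if cached is not None:
            return cached
        result = super().predict(prompt, **prompt_args)
        put(cache_key, result, cache_obj=self.cache)
        return result

The Cache instance would need to be initialized once before use, for example cache_obj = Cache() followed by cache_obj.init(pre_embedding_func=get_prompt) with get_prompt from gptcache.processor.pre, in place of the GPTCache(self.init_gptcache) wrapper shown further down.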

My actual problem is that I have to add GPTCache to existing LlamaIndex calls. My current implementation is below.

query_engine = self.__llama_idx_svc.get_query_engine(tenant_id,
                                                     tenant_index,
                                                     tenant_config,
                                                     model_name,
                                                     node_postprocessors=node_postprocessors,
                                                     text_qa_template=text_qa_template,
                                                     synthesizer_mode=synthesizer_mode)

if query_engine is None:
    raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Unable to construct Query Engine")

response = query_engine.query(query)
def __get_vector_store_qe(self,
                              tenant_index: Index,
                              tenant_config: Config,
                              model_name: Optional[str] = None,
                              node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
                              text_qa_template: Optional[BasePromptTemplate] = None,
                              synthesizer_mode: Optional[str] = None) -> BaseQueryEngine:
        # Load Index
        loaded_index = self.__index_svc.load_index(tenant_index, tenant_config, model_name)

        # Get Context & LLM
        service_ctx = loaded_index.service_context
        callback_manager = service_ctx.callback_manager

        # Get config
        related_series_prompt = tenant_config.config.get("qna_related_series_prompt", QNA_DEFAULT_RELATED_SERIES_PROMPT )

        # Configure prompt template
        text_qa_template = text_qa_template or None

        if str(tenant_index.name) == 'series_metadata': # TODO - Hardcoded, consider to move to Index Config
            text_qa_template = Prompt(
                related_series_prompt[0],
                prompt_type=PromptType.QUESTION_ANSWER
            )

        # postprocessing setup
        node_postprocessors = node_postprocessors or []

        retriever = VectorIndexRetriever(
            index=loaded_index, 
            similarity_top_k=20,
        )

        # Configure response synthesizer
        synthesizer_mode = synthesizer_mode or "compact"

        response_synthesizer = get_response_synthesizer(
            service_context=service_ctx,
            callback_manager=callback_manager,
            text_qa_template=text_qa_template,
            response_mode=synthesizer_mode,
        )

        # Assemble query engine
        return RetrieverQueryEngine(retriever=retriever,
                                    response_synthesizer=response_synthesizer,
                                    callback_manager=callback_manager,
                                    node_postprocessors=node_postprocessors)

def load_index(self, tenant_index: Index, tenant_config: Config, model_name: Optional[str] = None):

    if tenant_index.type == 'sql_store' or tenant_index.type == 'sql_store_with_meta':
        return self.__load_sql_index(tenant_index, tenant_config, model_name)
    else:
        return self.__load_vector_index(tenant_index, tenant_config, model_name)     

def get_content_func(data, **_):
    return data.get("prompt").split("Question")[-1]   

# TODO - Move to LlamaIndexSvc
def __load_vector_index(self,
                        tenant_index: Index,
                        tenant_config: Config,
                        model_name: Optional[str] = None):

    gptcache_obj = GPTCache(self.init_gptcache)

    docstore = MongoDocumentStore.from_uri(db_name=os.getenv("INDEX_DB_NAME"),
                                           uri=os.getenv("INDEX_MONGODB_URL"),
                                           namespace=tenant_index.docstore_namespace)

    vector_store = PGVectorStore.from_params(database=os.getenv("INDEX_DB_NAME"),
                                             host=self.__postgres_host,
                                             port=self.__postgres_port,
                                             user=self.__postgres_user,
                                             password=self.__postgres_pass,
                                             table_name=tenant_index.vector_store_table,
                                             embed_dim=tenant_index.embed_dim,
                                             hybrid_search=tenant_index.hybrid_search,
                                             text_search_config=tenant_index.text_search_config)

    storage_ctx = StorageContext.from_defaults(index_store=self.__index_store,
                                               docstore=docstore,
                                               vector_store=vector_store)

    # get config
    llm_temperature = tenant_config.config.get("llm_temperature", 0)
    llm_num_outputs = tenant_config.config.get("llm_num_outputs", None)
    llm_api_key = tenant_config.config.get("llm_api_key", "")
    llm_model = tenant_config.config.get("llm_model", "")
    if model_name and model_name in MODEL_NAME_MAP:
        llm_model = MODEL_NAME_MAP[model_name] 

    node_parser = SimpleNodeParser(
        text_splitter=TokenTextSplitter(
            chunk_size=tenant_index.chunk_size,
            chunk_overlap=tenant_index.max_chunk_overlap,
            callback_manager=self.__get_llm_callback_manager(),
        ),
    )

    embed_model = OpenAIEmbedding(api_key=llm_api_key)
    llm = OpenAI(temperature=llm_temperature,
                 max_tokens=llm_num_outputs,
                 model=llm_model,
                 api_key=llm_api_key)

    service_ctx = ServiceContext.from_defaults(llm_predictor=CachedLLMPredictor(llm=llm, cache=gptcache_obj),
                                               embed_model=embed_model,
                                               node_parser=node_parser,
                                               callback_manager=self.__get_llm_callback_manager())

    return load_index_from_storage(index_id=str(tenant_index.id),
                                   storage_context=storage_ctx,
                                   service_context=service_ctx)
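
The init_gptcache callback passed to GPTCache above is not included in the paste. Assuming GPTCache here is the LangChain-style wrapper whose init callback receives the cache object and an LLM string, a hypothetical sketch that reuses the get_content_func defined above (the "map" data manager and the directory naming are placeholders, not the real configuration of this service):

import hashlib

from gptcache import Cache
from gptcache.manager.factory import manager_factory


def init_gptcache(self, cache_obj: Cache, llm: str) -> None:
    # one exact-match ("map") store per LLM; the llm string is hashed because it can be long
    hashed_llm = hashlib.sha256(llm.encode()).hexdigest()[:16]
    cache_obj.init(
        pre_embedding_func=get_content_func,
        data_manager=manager_factory(manager="map", data_dir=f"map_cache_{hashed_llm}"),
    )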

Expected Behavior

Need to implement GPT caching in the LLM calls.

Steps To Reproduce

Run the code above.

Environment

No response

Anything else?

No response

SimFG commented 9 months ago

You only need to handle pydantic's attribute checks in the class, and then you can use GPTCache directly. Or you can build an OpenAI proxy service and use GPTCache in that service.
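
For the second option, the usual pattern is GPTCache's drop-in replacement for the openai module, so the caching happens below the framework rather than inside a custom predictor. A minimal sketch of that pattern, showing the generic adapter usage rather than anything wired into LlamaIndex here:

from gptcache import cache
from gptcache.adapter import openai  # drop-in replacement for the openai package

cache.init()            # default exact-match cache
cache.set_openai_key()  # reads OPENAI_API_KEY from the environment

# same call signature as openai.ChatCompletion.create; repeated questions are answered from the cache
response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what is GPTCache"}],
)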

MeghaWalia-eco commented 9 months ago

I am not getting any pydantic error, but when I try to set or retrieve a cache key, the self.cache[cache_key] = result and return self.cache[cache_key] lines throw errors and it does not work.

Can I get an example, based on the code above, of how to do that?

SimFG commented 9 months ago

What is the error?

MeghaWalia-eco commented 9 months ago

File "C:\AILatestClone\EconomistDigitalSolutions\openai-hack\app\service\CachedLLMPredictor.py", line 22, in predict
if cache_key in self.cache: ^^^^^^^^^^^^^^^^^^^^^^^ TypeError: argument of type 'GPTCache' is not iterable

File "C:\AILatestClone\EconomistDigitalSolutions\openai-hack\app\service\CachedLLMPredictor.py", line 26, in predict
self.cache[cache_key] = result


TypeError: 'GPTCache' object does not support item assignment

I think I am not accessing the cache correctly.

SachinGanesh commented 8 months ago

@MeghaWalia-eco

Were you able to solve this issue?

MeghaWalia-eco commented 8 months ago

@SachinGanesh No