run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Bug]: PropertyGraphIndex in some cases throws an error when used to create a chat_engine or query_engine. This breaks the index.as_chat_engine() and index.as_query_engine() methods. #14557

Open gich2009 opened 3 months ago

gich2009 commented 3 months ago

Bug Description

When a PropertyGraphIndex is used to create a chat_engine or a query_engine, certain keyword arguments cause an error when passed in. An easy fix is to give the underlying VectorStoreQuery (llama_index/core/vector_stores/types.py) a **kwargs parameter so it ignores arguments that are consumed elsewhere but still get passed down to it.

Current Implementation:

@dataclass
class VectorStoreQuery:
    """Vector store query."""

    query_embedding: Optional[List[float]] = None
    similarity_top_k: int = 1
    doc_ids: Optional[List[str]] = None
    node_ids: Optional[List[str]] = None
    query_str: Optional[str] = None
    output_fields: Optional[List[str]] = None
    embedding_field: Optional[str] = None

    mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT

    # NOTE: only for hybrid search (0 for bm25, 1 for vector search)
    alpha: Optional[float] = None

    # metadata filters
    filters: Optional[MetadataFilters] = None

    # only for mmr
    mmr_threshold: Optional[float] = None

    # NOTE: currently only used by postgres hybrid search
    sparse_top_k: Optional[int] = None
    # NOTE: return top k results from hybrid search. similarity_top_k is used for dense search top k
    hybrid_top_k: Optional[int] = None
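
A quick illustration of the failure mode (response_mode is exactly the stray keyword from the traceback below): the dataclass-generated __init__ is strict, so any keyword it does not declare raises immediately.

# Any undeclared keyword is rejected by the generated __init__:
VectorStoreQuery(similarity_top_k=4, response_mode="compact")
# TypeError: VectorStoreQuery.__init__() got an unexpected keyword argument 'response_mode'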

Fix that works:

@dataclass
class VectorStoreQuery:
    """Vector store query."""

    query_embedding: Optional[List[float]] = None
    similarity_top_k: int = 1
    doc_ids: Optional[List[str]] = None
    node_ids: Optional[List[str]] = None
    query_str: Optional[str] = None
    output_fields: Optional[List[str]] = None
    embedding_field: Optional[str] = None
    mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT
    alpha: Optional[float] = None
    filters: Optional[MetadataFilters] = None
    mmr_threshold: Optional[float] = None
    sparse_top_k: Optional[int] = None
    hybrid_top_k: Optional[int] = None

    def __init__(self, **kwargs):
        # Defining __init__ by hand suppresses the dataclass-generated one.
        # Known fields fall back to the class-level defaults above; unknown
        # keywords are silently ignored.
        self.query_embedding = kwargs.get('query_embedding', self.query_embedding)
        self.similarity_top_k = kwargs.get('similarity_top_k', self.similarity_top_k)
        self.doc_ids = kwargs.get('doc_ids', self.doc_ids)
        self.node_ids = kwargs.get('node_ids', self.node_ids)
        self.query_str = kwargs.get('query_str', self.query_str)
        self.output_fields = kwargs.get('output_fields', self.output_fields)
        self.embedding_field = kwargs.get('embedding_field', self.embedding_field)
        self.mode = kwargs.get('mode', self.mode)
        self.alpha = kwargs.get('alpha', self.alpha)
        self.filters = kwargs.get('filters', self.filters)
        self.mmr_threshold = kwargs.get('mmr_threshold', self.mmr_threshold)
        self.sparse_top_k = kwargs.get('sparse_top_k', self.sparse_top_k)
        self.hybrid_top_k = kwargs.get('hybrid_top_k', self.hybrid_top_k)
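
As a sanity check, here is a minimal sketch of why the override works (LenientQuery is a trimmed-down stand-in, not the real VectorStoreQuery): attribute lookup on self falls back to the class-level defaults, and the stray keyword is simply dropped.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class LenientQuery:
    # Two representative fields; the real class has more.
    query_embedding: Optional[List[float]] = None
    similarity_top_k: int = 1

    def __init__(self, **kwargs):
        # Known fields are pulled from kwargs, defaulting to the class
        # attributes; anything else (e.g. response_mode) is ignored.
        self.query_embedding = kwargs.get('query_embedding', self.query_embedding)
        self.similarity_top_k = kwargs.get('similarity_top_k', self.similarity_top_k)

q = LenientQuery(similarity_top_k=4, response_mode='compact')  # no TypeError
assert q.similarity_top_k == 4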

Version

0.10.52

Steps to Reproduce

chat_engine = property_graph_index.as_chat_engine(
    chat_mode=chat_mode,
    llm=llm,
    similarity_top_k=similarity_top_k,
    ## All the parameters below throw an error because of VectorStoreQuery.__init__()
    # use_async=True,  # This is already passed in by PGRetriever.
    service_context=service_context,
    response_mode=response_mode,
    verbose=verbose,
    max_function_calls=max_agent_iterations,
    max_iterations=max_agent_iterations,
    node_postprocessors=node_postprocessors,
)

Relevant Logs/Tracebacks

response: AgentChatResponse = chat_engine.chat(input_text)
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/instrumentation/dispatcher.py", line 230, in wrapper
    result = func(*args, **kwargs)
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/callbacks/utils.py", line 41, in wrapper
    return func(self, *args, **kwargs)
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/chat_engine/condense_plus_context.py", line 292, in chat
    chat_messages, context_source, context_nodes = self._run_c3(
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/chat_engine/condense_plus_context.py", line 208, in _run_c3
    context_str, context_nodes = self._retrieve_context(condensed_question)
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/chat_engine/condense_plus_context.py", line 169, in _retrieve_context
    nodes = self._retriever.retrieve(message)
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/instrumentation/dispatcher.py", line 230, in wrapper
    result = func(*args, **kwargs)
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/base/base_retriever.py", line 243, in retrieve
    nodes = self._retrieve(query_bundle)
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/instrumentation/dispatcher.py", line 230, in wrapper
    result = func(*args, **kwargs)
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/indices/property_graph/retriever.py", line 52, in _retrieve
    return asyncio_run(self._aretrieve(query_bundle))
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/async_utils.py", line 47, in asyncio_run
    raise type(e)(f"Error running coroutine: {e!s}") from e
TypeError: Error running coroutine: VectorStoreQuery.__init__() got an unexpected keyword argument 'response_mode'
dosubot[bot] commented 3 months ago

To fix the error in the PropertyGraphIndex when creating a chat_engine or query_engine, you need to add a **kwargs parameter to the VectorStoreQuery class in llama_index/core/vector_stores/types.py. Here is the modified VectorStoreQuery class with the **kwargs parameter:

@dataclass
class VectorStoreQuery:
    """Vector store query."""

    query_embedding: Optional[List[float]] = None
    similarity_top_k: int = 1
    doc_ids: Optional[List[str]] = None
    node_ids: Optional[List[str]] = None
    query_str: Optional[str] = None
    output_fields: Optional[List[str]] = None
    embedding_field: Optional[str] = None
    mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT
    alpha: Optional[float] = None
    filters: Optional[MetadataFilters] = None
    mmr_threshold: Optional[float] = None
    sparse_top_k: Optional[int] = None
    hybrid_top_k: Optional[int] = None

    # NOTE: a dataclass cannot declare ``**kwargs`` as a field (that line
    # would be a SyntaxError), so accept and silently drop unknown keyword
    # arguments via a hand-written __init__ instead
    # (requires ``from dataclasses import fields``).
    def __init__(self, **kwargs: Any) -> None:
        for f in fields(self):
            setattr(self, f.name, kwargs.get(f.name, f.default))

Additionally, ensure that the VectorContextRetriever class in llama_index/core/indices/property_graph/sub_retrievers/vector.py correctly passes these **kwargs when creating a VectorStoreQuery:

class VectorContextRetriever(BasePGRetriever):
    """A retriever that uses a vector store to retrieve nodes based on a query."""

    def __init__(
        self,
        graph_store: PropertyGraphStore,
        include_text: bool = True,
        embed_model: Optional[BaseEmbedding] = None,
        vector_store: Optional[VectorStore] = None,
        similarity_top_k: int = 4,
        path_depth: int = 1,
        filters: Optional[MetadataFilters] = None,
        **kwargs: Any
    ) -> None:
        self._retriever_kwargs = kwargs or {}
        self._embed_model = embed_model or Settings.embed_model
        self._similarity_top_k = similarity_top_k
        self._vector_store = vector_store
        self._path_depth = path_depth
        self._filters = filters

        super().__init__(graph_store=graph_store, include_text=include_text, **kwargs)

    def _get_vector_store_query(self, query_bundle: QueryBundle) -> VectorStoreQuery:
        if query_bundle.embedding is None:
            query_bundle.embedding = self._embed_model.get_agg_embedding_from_queries(
                query_bundle.embedding_strs
            )
        return VectorStoreQuery(
            query_embedding=query_bundle.embedding,
            similarity_top_k=self._similarity_top_k,
            filters=self._filters,
            **self._retriever_kwargs,
        )

    async def _aget_vector_store_query(
        self, query_bundle: QueryBundle
    ) -> VectorStoreQuery:
        if query_bundle.embedding is None:
            query_bundle.embedding = (
                await self._embed_model.aget_agg_embedding_from_queries(
                    query_bundle.embedding_strs
                )
            )

        return VectorStoreQuery(
            query_embedding=query_bundle.embedding,
            similarity_top_k=self._similarity_top_k,
            filters=self._filters,
            **self._retriever_kwargs,
        )

This ensures that any additional keyword arguments are properly passed through the VectorContextRetriever to the VectorStoreQuery [1][2].
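
For instance, a hypothetical usage sketch (graph_store is assumed to be an existing PropertyGraphStore instance): any extra keyword given to the modified retriever is stored in _retriever_kwargs and forwarded into every VectorStoreQuery it builds.

retriever = VectorContextRetriever(
    graph_store=graph_store,  # assumed: an existing PropertyGraphStore
    similarity_top_k=4,
    hybrid_top_k=10,  # lands in _retriever_kwargs, then in VectorStoreQuery
)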


gich2009 commented 3 months ago

Another suggestion: the code that constructs PGRetriever could also check whether use_async has been passed to it in kwargs.

Current Implementation:

return PGRetriever(sub_retrievers, use_async=self._use_async, **kwargs)

Proposed Implementation:

use_async = kwargs.pop("use_async", self._use_async)
return PGRetriever(sub_retrievers, use_async=use_async, **kwargs)
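
A minimal sketch of why the pop matters (make_pg_retriever is a hypothetical stand-in with the same calling convention as the PGRetriever constructor): if use_async is passed explicitly while a copy is still inside **kwargs, Python raises a "got multiple values" TypeError; popping it first removes the duplicate and lets the caller-supplied value win.

def make_pg_retriever(sub_retrievers, use_async=True, **kwargs):
    # Hypothetical stand-in for the PGRetriever constructor.
    return (sub_retrievers, use_async, kwargs)

kwargs = {"use_async": False, "verbose": True}

# Current implementation: use_async is supplied twice and raises
#   TypeError: got multiple values for keyword argument 'use_async'
# make_pg_retriever(["vector"], use_async=True, **kwargs)

# Proposed implementation: the duplicate is removed before forwarding.
use_async = kwargs.pop("use_async", True)
print(make_pg_retriever(["vector"], use_async=use_async, **kwargs))
# (['vector'], False, {'verbose': True})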