run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai

How to use the same sub nodes for densex second time #14297

Open LikhithRishi opened 2 months ago

LikhithRishi commented 2 months ago

Question Validation

Question

class DenseXRetrievalPack(BaseLlamaPack):

def __init__(
    self,
    documents: List[Document],
    proposition_llm: Optional[LLM] = None,
    query_llm: Optional[LLM] = None,
    embed_model: Optional[BaseEmbedding] = None,
    text_splitter: TextSplitter = SentenceSplitter(),
    vector_store: Optional[ElasticsearchStore] = None,
    similarity_top_k: int = 4,
) -> None:
    """Init params."""
    self._proposition_llm = proposition_llm

    embed_model = embed_model

    nodes = text_splitter.get_nodes_from_documents(documents)
    print(nodes)
    sub_nodes = self._gen_propositions(nodes)
    print(sub_nodes,"greg")
    all_nodes = nodes + sub_nodes
    all_nodes_dict = {n.node_id: n for n in all_nodes}

    service_context = ServiceContext.from_defaults(
        llm=query_llm ,
        embed_model=embed_model,
        num_output=self._proposition_llm.metadata.num_output,
    )
    '''
    if os.path.exists('./elastic_db'):
        print("From elasticsearch")
        self.vector_index = VectorStoreIndex.from_vector_store(vector_store,service_context=service_context)
    else:
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        self.vector_index = VectorStoreIndex(
             all_nodes, service_context=service_context, show_progress=True,storage_context=storage_context
             )
        os.mkdir("elastic_db")
    '''
    if os.path.exists('./chroma_db'):
        chroma_client = chromadb.PersistentClient(path="./chroma_db")
        chroma_collection = chroma_client.get_or_create_collection("quickstart")
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        self.vector_index = VectorStoreIndex.from_vector_store(vector_store,service_context=service_context)
    else:
       chroma_client = chromadb.PersistentClient(path="./chroma_db")
       chroma_collection = chroma_client.get_or_create_collection("quickstart")
       vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
       storage_context = StorageContext.from_defaults(vector_store=vector_store)
       self.vector_index = VectorStoreIndex(
            all_nodes, service_context=service_context, show_progress=True,storage_context=storage_context,store_nodes_override=True
            )
    self.retriever = RecursiveRetriever(
        "vector",
        retriever_dict={
            "vector": self.vector_index.as_retriever(
                similarity_top_k=similarity_top_k
            )
        },
        node_dict=all_nodes_dict,
    )

    self.query_engine = RetrieverQueryEngine.from_args(
        self.retriever, service_context=service_context
    )

async def _aget_proposition(self, node: TextNode) -> List[TextNode]:
    """Get proposition."""
    inital_output = await self._proposition_llm.apredict(
        PROPOSITIONS_PROMPT, node_text=node.text
    )
    outputs = inital_output.split("\n")

    all_propositions = []

    for output in outputs:
        if not output.strip():
            continue
        if not output.strip().endswith("]"):
            if not output.strip().endswith('"') and not output.strip().endswith(
                ","
            ):
                output = output + '"'
            output = output + " ]"
        if not output.strip().startswith("["):
            if not output.strip().startswith('"'):
                output = '"' + output
            output = "[ " + output

        try:
            propositions = json.loads(output)
        except Exception:
            # fallback to yaml
            try:
                propositions = yaml.safe_load(output)
            except Exception:
                # fallback to next output
                continue

        if not isinstance(propositions, list):
            continue

        all_propositions.extend(propositions)

    assert isinstance(all_propositions, list)
    nodes = [TextNode(text=prop) for prop in all_propositions if prop]

    return [IndexNode.from_text_node(n, node.node_id) for n in nodes]

def _gen_propositions(self, nodes: List[TextNode]) -> List[TextNode]:
    """Get propositions."""
    sub_nodes = asyncio.run(
        run_jobs(
            [self._aget_proposition(node) for node in nodes],
            show_progress=True,
            workers=8,
        )
    )
    print(sub_nodes)

    # Flatten list
    return [node for sub_node in sub_nodes for node in sub_node]

def get_modules(self) -> Dict[str, Any]:
    """Get modules."""
    return {
        "query_engine": self.query_engine,
        "retriever": self.retriever,
    }

def run(self, query_str: str, **kwargs: Any) -> RESPONSE_TYPE:
    """Run the pipeline."""
    return self.query_engine.query(query_str)

Our aim is to store the embeddings in ChromaDB and reuse the same embeddings on a second run. But in the RetrieverQueryEngine, all_nodes_dict is not fixed, because the sub-nodes generated by the LLM are not deterministic. How can we store the sub-nodes and reuse the same ones the second time?

dosubot[bot] commented 2 months ago

To ensure that the same sub-nodes are used consistently across multiple runs, you need to store the generated sub-nodes and reload them when needed. Here’s a way to achieve this:

  1. Store the sub-nodes: Serialize the sub-nodes to a file after they are generated.
  2. Load the sub-nodes: Deserialize the sub-nodes from the file when initializing the DenseXRetrievalPack class (a minimal round-trip sketch follows this list).
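
As a minimal sketch of what that serialize/deserialize round trip looks like for a single proposition sub-node, assuming the legacy (pre-0.10) llama_index schema module and using illustrative variable names:

from llama_index.schema import IndexNode, TextNode

# A parent chunk and one proposition sub-node pointing back to it, as in the pack.
parent = TextNode(text="Some chunk of the source document.")
sub = IndexNode.from_text_node(
    TextNode(text="A single proposition extracted by the LLM."),
    parent.node_id,
)

# to_dict() gives a JSON-serializable dict; from_dict() restores the node,
# including its node_id and the index_id link back to the parent chunk.
data = sub.to_dict()
restored = IndexNode.from_dict(data)
assert restored.node_id == sub.node_id
assert restored.index_id == parent.node_id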

Here’s an example of how you can modify your class to include this functionality:

import asyncio
import json
import os
from typing import Any, Dict, List, Optional

import chromadb
import yaml

# NOTE: the import paths below assume the legacy (pre-0.10) llama_index layout.
# Other names used here (BaseLlamaPack, BaseEmbedding, TextSplitter,
# ElasticsearchStore, run_jobs, PROPOSITIONS_PROMPT, RESPONSE_TYPE) are assumed
# to be imported exactly as in the original dense_x_retrieval pack source.
from llama_index import ServiceContext, StorageContext, VectorStoreIndex
from llama_index.llms import LLM
from llama_index.node_parser import SentenceSplitter
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.retrievers import RecursiveRetriever
from llama_index.schema import Document, IndexNode, TextNode
from llama_index.vector_stores import ChromaVectorStore

class DenseXRetrievalPack(BaseLlamaPack):
    def __init__(
        self,
        documents: List[Document],
        proposition_llm: Optional[LLM] = None,
        query_llm: Optional[LLM] = None,
        embed_model: Optional[BaseEmbedding] = None,
        text_splitter: TextSplitter = SentenceSplitter(),
        vector_store: Optional[ElasticsearchStore] = None,
        similarity_top_k: int = 4,
    ) -> None:
        """Init params."""
        self._proposition_llm = proposition_llm

        embed_model = embed_model
        nodes = text_splitter.get_nodes_from_documents(documents)
        sub_nodes = self._load_or_generate_sub_nodes(nodes)
        all_nodes = nodes + sub_nodes
        all_nodes_dict = {n.node_id: n for n in all_nodes}

        service_context = ServiceContext.from_defaults(
            llm=query_llm,
            embed_model=embed_model,
            num_output=self._proposition_llm.metadata.num_output,
        )

        if os.path.exists('./chroma_db'):
            chroma_client = chromadb.PersistentClient(path="./chroma_db")
            chroma_collection = chroma_client.get_or_create_collection("quickstart")
            vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
            self.vector_index = VectorStoreIndex.from_vector_store(vector_store, service_context=service_context)
        else:
            chroma_client = chromadb.PersistentClient(path="./chroma_db")
            chroma_collection = chroma_client.get_or_create_collection("quickstart")
            vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
            self.vector_index = VectorStoreIndex(
                all_nodes, service_context=service_context, show_progress=True, storage_context=storage_context, store_nodes_override=True
            )

        self.retriever = RecursiveRetriever(
            "vector",
            retriever_dict={
                "vector": self.vector_index.as_retriever(similarity_top_k=similarity_top_k)
            },
            node_dict=all_nodes_dict,
        )

        self.query_engine = RetrieverQueryEngine.from_args(
            self.retriever, service_context=service_context
        )

    def _load_or_generate_sub_nodes(self, nodes: List[TextNode]) -> List[IndexNode]:
        """Load sub-nodes from file, or generate and persist them if not available."""
        sub_nodes_file = 'sub_nodes.json'
        if os.path.exists(sub_nodes_file):
            with open(sub_nodes_file, 'r') as f:
                sub_nodes_data = json.load(f)
            # Rebuild as IndexNode so the index_id link back to the parent node
            # survives the round trip; TextNode.from_dict would drop it.
            sub_nodes = [IndexNode.from_dict(data) for data in sub_nodes_data]
        else:
            sub_nodes = self._gen_propositions(nodes)
            with open(sub_nodes_file, 'w') as f:
                json.dump([node.to_dict() for node in sub_nodes], f)
        return sub_nodes

    async def _aget_proposition(self, node: TextNode) -> List[TextNode]:
        """Get proposition."""
        inital_output = await self._proposition_llm.apredict(
            PROPOSITIONS_PROMPT, node_text=node.text
        )
        outputs = inital_output.split("\n")

        all_propositions = []

        for output in outputs:
            if not output.strip():
                continue
            if not output.strip().endswith("]"):
                if not output.strip().endswith('"') and not output.strip().endswith(","):
                    output = output + '"'
                output = output + " ]"
            if not output.strip().startswith("["):
                if not output.strip().startswith('"'):
                    output = '"' + output
                output = "[ " + output

            try:
                propositions = json.loads(output)
            except Exception:
                # fallback to yaml
                try:
                    propositions = yaml.safe_load(output)
                except Exception:
                    # fallback to next output
                    continue

            if not isinstance(propositions, list):
                continue

            all_propositions.extend(propositions)

        assert isinstance(all_propositions, list)
        nodes = [TextNode(text=prop) for prop in all_propositions if prop]

        return [IndexNode.from_text_node(n, node.node_id) for n in nodes]

    def _gen_propositions(self, nodes: List[TextNode]) -> List[TextNode]:
        """Get propositions."""
        sub_nodes = asyncio.run(
            run_jobs(
                [self._aget_proposition(node) for node in nodes],
                show_progress=True,
                workers=8,
            )
        )

        # Flatten list
        return [node for sub_node in sub_nodes for node in sub_node]

    def get_modules(self) -> Dict[str, Any]:
        """Get modules."""
        return {
            "query_engine": self.query_engine,
            "retriever": self.retriever,
        }

    def run(self, query_str: str, **kwargs: Any) -> RESPONSE_TYPE:
        """Run the pipeline."""
        return self.query_engine.query(query_str)

This approach ensures that the sub-nodes are stored after the first generation and reused in subsequent runs, maintaining consistency in your retrieval process.
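
For completeness, a minimal usage sketch of how two successive runs would behave, assuming `documents`, `proposition_llm`, `query_llm` and `embed_model` have already been constructed (the names are illustrative, not from the issue):

# First run: neither sub_nodes.json nor ./chroma_db exists yet, so the propositions
# are generated by the LLM, written to sub_nodes.json, and embedded into Chroma.
pack = DenseXRetrievalPack(
    documents=documents,
    proposition_llm=proposition_llm,
    query_llm=query_llm,
    embed_model=embed_model,
)
print(pack.run("What does the document say about X?"))

# Second run (e.g. after a restart): sub_nodes.json and ./chroma_db already exist,
# so the same sub-nodes are reloaded and the stored embeddings are reused instead
# of being regenerated.
pack = DenseXRetrievalPack(
    documents=documents,
    proposition_llm=proposition_llm,
    query_llm=query_llm,
    embed_model=embed_model,
)
print(pack.run("What does the document say about X?"))

One practical note: if the source documents change, deleting sub_nodes.json and ./chroma_db together keeps the persisted sub-nodes consistent with the stored embeddings.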

To continue talking to Dosu, mention @dosu.