run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Question]: top_k and context windows vs. Refinement process #10540

Closed jc1518 closed 9 months ago

jc1518 commented 9 months ago

Question Validation

Question

I have observed something interesting, and I don't quite understand what is happening under the hood. Can someone please share some insights? Hi again @logan-markewich, @jerryjliu, @ravi-theja, any thoughts?

I ran two tests, both with Claude v2 (context window is 100k, and I set max_tokens_to_sample = 4000) and llama-index 0.9.45. The vector store embedding chunk size is 512, and the response mode is compact.

In the first test, I set top_k=1. From the following trace, I can see it made one LLM call to synthesize the response, which makes sense to me. The statement "Based on the provided context" shows it did not go through the refinement process.

Query: what is xyz?
**********
Trace: index_construction
**********
**********
Trace: query
    |_CBEventType.QUERY ->  1.114918 seconds
      |_CBEventType.RETRIEVE ->  1.100808 seconds
        |_CBEventType.EMBEDDING ->  1.065354 seconds
      |_CBEventType.SYNTHESIZE ->  0.013894 seconds
        |_CBEventType.TEMPLATING ->  3.4e-05 seconds
        |_CBEventType.LLM ->  0.0 seconds
    |_CBEventType.LLM ->  0.0 seconds
**********
Token usage for query: what is xyz?
 Embedding Tokens:  6 
 LLM Prompt Tokens:  1112 
 LLM Completion Tokens:  288 
 Total LLM Token Count:  1400 

Answer: Based on the provided context, xyz refers to....

In the second test, I set top_k=2. From the following trace, I can see it made four LLM calls to synthesize the response, and the statement "Based on the additional context provided, I would make the following minor updates to the existing answer" shows it went through the refinement process. Now my question is: the context window is 100K and the response mode is compact, so isn't it supposed to be one LLM call?

Query: what is xyz?
**********
Trace: index_construction
**********
**********
Trace: query
    |_CBEventType.QUERY ->  15.731541 seconds
      |_CBEventType.RETRIEVE ->  0.942077 seconds
        |_CBEventType.EMBEDDING ->  0.906755 seconds
      |_CBEventType.SYNTHESIZE ->  14.789295 seconds
        |_CBEventType.TEMPLATING ->  3.7e-05 seconds
        |_CBEventType.LLM ->  5.473105 seconds
        |_CBEventType.TEMPLATING ->  2.1e-05 seconds
        |_CBEventType.LLM ->  5.761943 seconds
        |_CBEventType.TEMPLATING ->  2e-05 seconds
        |_CBEventType.LLM ->  3.526165 seconds
        |_CBEventType.TEMPLATING ->  2.2e-05 seconds
        |_CBEventType.LLM ->  0.0 seconds
    |_CBEventType.LLM ->  5.47279 seconds
    |_CBEventType.LLM ->  5.761676 seconds
    |_CBEventType.LLM ->  3.525903 seconds
    |_CBEventType.LLM ->  0.0 seconds
**********
Token usage for query: what is xyz?
 Embedding Tokens:  6 
 LLM Prompt Tokens:  4203 
 LLM Completion Tokens:  739 
 Total LLM Token Count:  4942

Based on the additional context provided, I would make the following minor updates to the existing answer:

xyz outlines 

Here are my prompts for text_qa and refine:

# Prompts
qa_prompt = (
    "You are a Q&A assistant, "
    "you are responsible for answering the query with given context. "
    "Context information is below.\n"
    "--------------------\n"
    "{context_str}\n"
    "--------------------\n"
    "Given the context information not prior knowledge, "
    "answer the query in a concise and clear way. "
    "If the answer is not in the context, return 'I don't know'. \n"
    "Query: {query_str}\n"
    "Answer: "
)

refine_prompt = (
    "The original query is as following: "
    "{query_str}\n"
    "We have provided an existing answer: "
    "{existing_answer}\n"
    "We have the opportunity to refine "
    "the existing answer with some more context below. "
    "--------------------\n"
    "{context_msg}\n"
    "--------------------\n"
    "Given the context information not prior knowledge, "
    "update the existing answer only if necessary.\n"
)
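
For reference, here is roughly how these two prompt strings get wired into the query engine (a sketch of the llama-index 0.9.x API; "index" is an assumed handle to the vector index, and the variable names mirror the simple_query_engine call shown further down):

from llama_index.prompts import PromptTemplate

text_qa_template = PromptTemplate(qa_prompt)
refine_prompt_template = PromptTemplate(refine_prompt)

query_engine = index.as_query_engine(
    response_mode="compact",
    similarity_top_k=2,
    text_qa_template=text_qa_template,
    refine_template=refine_prompt_template,
)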
dosubot[bot] commented 9 months ago

🤖

Hey @jc1518, great to see you diving deep again! What's cooking this time?

The behavior you're observing is due to the design of the Refine class in the LlamaIndex framework. Specifically, the get_response method in this class is designed to iterate over the provided text_chunks and for each chunk, it either generates a response or refines the previous response.

Here's a simplified version of the relevant code:

def get_response(
    self,
    query_str: str,
    text_chunks: Sequence[str],
    prev_response: Optional[RESPONSE_TEXT_TYPE] = None,
    **response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
    response: Optional[RESPONSE_TEXT_TYPE] = None
    for text_chunk in text_chunks:
        if prev_response is None:
            response = self._give_response_single(
                query_str, text_chunk, **response_kwargs
            )
        else:
            response = self._refine_response_single(
                prev_response, query_str, text_chunk, **response_kwargs
            )
        prev_response = response
    return response

In this code, _give_response_single and _refine_response_single are the methods that make LLM calls: the first chunk produces an initial answer and every subsequent chunk refines the previous one, so each chunk handed to this loop results in one LLM call. Four LLM calls therefore means the synthesizer was iterating over four chunks, not two.

The top_k parameter controls how many chunks are retrieved, but with response_mode="compact" those chunks are first repacked to fit the context window known to the prompt helper before this loop runs. If that configured window is much smaller than Claude's actual 100K, the two retrieved chunks can end up re-split into several smaller packed chunks (four in your trace), which is why you're seeing four LLM calls and a refinement process even though the response mode is compact.
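
To make the packing step concrete, here is a conceptual sketch of the repacking idea (not the actual LlamaIndex implementation; encode and decode stand in for the tokenizer used by the prompt helper):

def repack(text_chunks, max_chunk_tokens, encode, decode):
    """Join the retrieved chunks, then re-split the result so every packed
    chunk fits in the token budget left after the prompt template and the
    reserved output tokens."""
    tokens = encode("\n\n".join(text_chunks))
    return [
        decode(tokens[i : i + max_chunk_tokens])
        for i in range(0, len(tokens), max_chunk_tokens)
    ]

# With two 512-token chunks and a large budget this yields one packed chunk
# (a single LLM call); with a small budget it yields several packed chunks,
# each of which costs one qa/refine call in the loop above.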

The top_k parameter itself acts on the retrieval side (here illustrated with the BM25Retriever class): it determines how many top-scoring nodes are returned after retrieval, and therefore how many text_chunks reach the synthesizer.

In the _retrieve method of the BM25Retriever class, the nodes are scored, sorted in descending order (highest scores first), and the top_k parameter is used to slice the sorted list to the top-scoring nodes.

Here is the relevant code:

# Sort and get top_k nodes, score range => 0..1, closer to 1 means more relevant
nodes = sorted(scored_nodes, key=lambda x: x.score or 0.0, reverse=True)
return nodes[: self._similarity_top_k]

In this code, self._similarity_top_k is the top_k parameter, and return nodes[: self._similarity_top_k] keeps only the top k nodes by score. The top_k parameter therefore directly limits how many nodes are passed on after retrieval.

You can find more details in the source code of the LlamaIndex repository.

I hope this helps clarify the behavior you're observing. If you have any more questions, feel free to ask.

jc1518 commented 9 months ago

@dosu-bot, for compact mode, isn't it supposed to stuff as many retrieved chunks as possible into one LLM call, as long as they fit within the context window?

logan-markewich commented 9 months ago

@jc1518 can you share how you set up the LLM and query engine?

jc1518 commented 9 months ago

Hi @logan-markewich, thanks for getting back to me.

I created a class that uses Bedrock as the service context:

"""Bedrock Service Context"""
import logging

import boto3
import tiktoken
from botocore.config import Config
from langchain_community.embeddings import BedrockEmbeddings
from langchain_community.llms import Bedrock
from llama_index import PromptHelper, ServiceContext
from llama_index.prompts import PromptTemplate
from llama_index.callbacks import (
    CallbackManager,
    LlamaDebugHandler,
    TokenCountingHandler,
)

from lib import settings

boto3.set_stream_logger(name="botocore.credentials", level=logging.ERROR)

class BedrockServiceContext:
    """Bedrock Service Context"""

    def __init__(self, **kwargs):
        self.region = kwargs.get("region", settings.bedrock_region)
        self.text_model = kwargs.get("text_model", settings.bedrock_text_model)
        self.embed_model = kwargs.get("embed_model", settings.bedrock_embed_model)
        self.config = Config(
            retries={"max_attempts": settings.max_attempts, "mode": "standard"}
        )
        self.chunk_size = kwargs.get("chunk_size", settings.chunk_size)
        self.streaming = kwargs.get("streaming", settings.streaming)
        self.verbose = kwargs.get("verbose", settings.verbose)
        self.print_trace_on_end = kwargs.get(
            "print_trace_on_end", settings.print_trace_on_end
        )
        self.context_window = kwargs.get("context_window", settings.context_window)
        self.num_output = kwargs.get("num_output", settings.num_output)
        self.chunk_overlap_ratio = kwargs.get(
            "chunk_overlap_ratio", settings.chunk_overlap_ratio
        )
        self.token_counter = kwargs.get(
            "token_counter",
            TokenCountingHandler(
                tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode
            ),
        )
        self.callback_manager = CallbackManager(
            [
                LlamaDebugHandler(print_trace_on_end=self.print_trace_on_end),
                self.token_counter,
            ]
        )
        self.prompt_helper = PromptHelper(
            context_window=self.context_window,
            num_output=self.num_output,
            chunk_overlap_ratio=self.chunk_overlap_ratio,
            chunk_size_limit=self.chunk_size,
        )
        self.temperature = kwargs.get("temperature", settings.temperature)
        self.max_tokens_to_sample = kwargs.get(
            "max_token_to_sample", settings.max_tokens_to_sample
        )
        self.model_kwargs = kwargs.get(
            "model_kwargs", {"temperature": self.temperature}
        )
        if self.text_model.startswith("anthropic.claude"):
            self.model_kwargs = kwargs.get(
                "model_kwargs",
                {
                    "temperature": self.temperature,
                    "max_tokens_to_sample": self.max_tokens_to_sample,
                },
            )
        if self.text_model.startswith("meta.llama2"):
            self.model_kwargs = kwargs.get(
                "model_kwargs",
                {
                    "temperature": self.temperature,
                    "top_p": settings.top_p,
                },
            )

    def _bedrock_client(self):
        """Bedrock client"""
        session = boto3.Session(region_name=self.region)
        boto3_bedrock = session.client(
            service_name="bedrock-runtime", config=self.config
        )
        return boto3_bedrock

    def _llm(self):
        """Bedrock predictor llm"""
        llm = Bedrock(
            client=self._bedrock_client(),
            model_id=self.text_model,
            streaming=self.streaming,
            verbose=self.verbose,
            model_kwargs=self.model_kwargs,
        )
        return llm

    def _embedding(self):
        """Bedrock embedding llm"""
        embedding = BedrockEmbeddings(
            client=self._bedrock_client(), model_id=self.embed_model
        )
        return embedding

    def for_query_engine(self):
        """Bedrock powered predictor service context"""
        service_context = ServiceContext.from_defaults(
            llm=self._llm(),
            embed_model=self._embedding(),
            callback_manager=self.callback_manager,
            prompt_helper=self.prompt_helper,
        )
        return service_context

The query engine lives in another class; here is the query-engine method that I used in the test:

    def simple_query_engine(
        self,
        index_type,
        streaming,
        similarity_top_k=default_similarity_top_k,
        response_mode=default_response_mode,
        query_transformation=False,
        text_qa_template=None,
        doc_id=None,
    ):
        """Simple query engine"""
        set_global_service_context(self.service_context)
        document = Document(self.llm_provider)
        index = document.load_index(index_type, doc_id)
        query_engine = index.as_query_engine(
            streaming=streaming,
            response_mode=response_mode,
            similarity_top_k=similarity_top_k,
            text_qa_template=text_qa_template,
            refine_template=refine_prompt_template,
            use_async=True,
        )
        if query_transformation:
            query_engine_tools = [
                QueryEngineTool(
                    query_engine=query_engine,
                    metadata=ToolMetadata(
                        name="Knowledge Base",
                        description="Use this tool for all queries",
                    ),
                )
            ]
            query_engine = SubQuestionQueryEngine.from_defaults(
                query_engine_tools=query_engine_tools, use_async=True
            )
        return query_engine
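
Usage then looks roughly like this (a sketch; "QueryEngine" is an assumed name for the class that owns simple_query_engine, everything else mirrors the code above):

bedrock = BedrockServiceContext(context_window=100000, chunk_size=512)
service_context = bedrock.for_query_engine()

# Hypothetical wrapper class holding simple_query_engine and the service context.
engine_factory = QueryEngine(service_context=service_context, llm_provider="bedrock")
query_engine = engine_factory.simple_query_engine(
    index_type="vector",
    streaming=False,
    similarity_top_k=2,
    response_mode="compact",
)
response = query_engine.query("what is xyz?")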
logan-markewich commented 9 months ago

@jc1518 the bedrock llm uses some hardcoded dictionary to look up the context window. My guess is that the lookup is failing for whatever reason here.

If you are using anthropic, try hardcoding the context window to 100k

llm = Bedrock(..., context_size=100000)

jc1518 commented 9 months ago

Thanks @logan-markewich, I have tried that, but context_size is not a valid parameter for the Bedrock class. I also searched the langchain repo and could not find anything relevant.

I will keep digging to see if I can find anything. I also ran another test with refine mode, which made four LLM calls as well, but the response is more verbose than in compact mode.
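
For reference, the only difference between the two runs is the response_mode passed to as_query_engine (a sketch of the standard API; "index" is the loaded vector index):

compact_engine = index.as_query_engine(similarity_top_k=2, response_mode="compact")
refine_engine = index.as_query_engine(similarity_top_k=2, response_mode="refine")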

Test with refine mode:

Query: what is xyz?
**********
Trace: index_construction
**********
**********
Trace: query
    |_CBEventType.QUERY ->  26.406014 seconds
      |_CBEventType.RETRIEVE ->  1.216888 seconds
        |_CBEventType.EMBEDDING ->  1.183715 seconds
      |_CBEventType.SYNTHESIZE ->  25.188966 seconds
        |_CBEventType.TEMPLATING ->  2.1e-05 seconds
        |_CBEventType.LLM ->  8.037965 seconds
        |_CBEventType.TEMPLATING ->  3.3e-05 seconds
        |_CBEventType.LLM ->  9.985102 seconds
        |_CBEventType.TEMPLATING ->  2.2e-05 seconds
        |_CBEventType.LLM ->  7.148235 seconds
        |_CBEventType.TEMPLATING ->  2.1e-05 seconds
        |_CBEventType.LLM ->  0.0 seconds
    |_CBEventType.LLM ->  8.037624 seconds
    |_CBEventType.LLM ->  9.979597 seconds
    |_CBEventType.LLM ->  7.147956 seconds
    |_CBEventType.LLM ->  0.0 seconds
**********
Token usage for query: what is xyz
 Embedding Tokens:  5 
 LLM Prompt Tokens:  4688 
 LLM Completion Tokens:  1442 
 Total LLM Token Count:  6130 

Based on the additional context provided, I would make the following minor updates to the existing summary answer:

xyz outlines 
logan-markewich commented 9 months ago

@jc1518 you are using langchain bedrock? Ohhhh that explains a lot.

We also have Bedrock support; it's what I was referencing:

https://docs.llamaindex.ai/en/stable/examples/llm/bedrock.html

Alternatively, you can also specify context_window=100000 in the service context
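
A minimal sketch of both options (assuming the llama-index 0.9.x API; the model id is taken from this thread):

from llama_index import ServiceContext
from llama_index.llms import Bedrock

# Option 1: the native llama-index Bedrock LLM, passing the context size
# explicitly so the hardcoded lookup is bypassed.
llm = Bedrock(model="anthropic.claude-v2", context_size=100000)

# Option 2: force the context window on the service context itself.
service_context = ServiceContext.from_defaults(llm=llm, context_window=100000)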

jc1518 commented 9 months ago

Ah, right. I had the context_window in the PromptHelper. I need to adjust it dynamically.
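
One way to do that (a sketch; the model-id prefixes and window sizes here are assumptions based on this thread, not values taken from the library):

# Hypothetical per-model lookup used to choose PromptHelper's context_window.
CONTEXT_WINDOWS = {
    "anthropic.claude": 100_000,  # Claude v2, as discussed above
    "meta.llama2": 4_096,         # assumed window for Llama 2
}

def context_window_for(model_id: str, default: int = 4_096) -> int:
    for prefix, window in CONTEXT_WINDOWS.items():
        if model_id.startswith(prefix):
            return window
    return default

# e.g. in BedrockServiceContext.__init__:
#   self.context_window = kwargs.get(
#       "context_window", context_window_for(self.text_model)
#   )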

jc1518 commented 9 months ago

Thanks again for your help @logan-markewich !!!