run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Question]: Weird latency when I use local:thenlper/gte-base as the embedding model to run a query-rewrite application #12228

Closed: lambda7xx closed this issue 4 months ago

lambda7xx commented 8 months ago


Question

I ran the query-rewrite application with the following code.


"""## Setup

### Data
"""
import time

import torch
import torch.cuda.nvtx as nvtx  # not used below; kept for optional NVTX profiling
# from transformers import BitsAndBytesConfig

"""### LLM

This should run on a T4 instance on the free tier
"""

from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.prompts import PromptTemplate
from llama_index.core.response.notebook_utils import display_response
from llama_index.llms.huggingface import HuggingFaceLLM

llm = HuggingFaceLLM(
    model_name="mistralai/Mistral-7B-Instruct-v0.1",
    tokenizer_name="mistralai/Mistral-7B-Instruct-v0.1",
    query_wrapper_prompt=PromptTemplate("<s>[INST] {query_str} [/INST] </s>\n"),
    context_window=3900,
    max_new_tokens=256,
    model_kwargs={"torch_dtype": torch.bfloat16},
    # tokenizer_kwargs={},
    generate_kwargs={"temperature": 0.2, "top_k": 5, "top_p": 0.95},
    device_map="auto",
)

Settings.llm = llm

# recordMemory("after embedding ")

llama2_paper_path = "../data/llama2_paper/llama2_paper.json"

import json
with open(llama2_paper_path) as f:
    data = json.load(f)
# keep only the first 20 example queries
origin_queries = [example['query'] for example in data['examples'][:20]]

#embeddings = ["local:BAAI/bge-small-en-v1.5", "local:thenlper/gte-base"]
embeddings = ["local:thenlper/gte-base"]

map_embeddings = dict()

for i in range(len(embeddings)):
    index = embeddings[i].find("/")
    if index == -1:
        index = embeddings[i].find(":")
    map_embeddings[embeddings[i]] = embeddings[i][index + 1:] 

column_name = ["embedding_name", "query_name", "load", "index", "query write", "query1","q1_len", "query2","q2_len", "query3", "q3_len","query4", "q4_len","query w/o rewrite", "q_wo_len","prompt on llm", "q_len"]

filename = "/home/ubuntu/uw-llama/query_rewrite/Mistral_llama2_paper_default_chunk_size_query_rewrite.csv"

load_time = 0 

index_time = 0

query_gen_str = """\
You are a helpful assistant that generates multiple search queries based on a \
single input query. Generate {num_queries} search queries, one on each line, \
related to the following input query:
Query: {query}
Queries:
"""
query_gen_prompt = PromptTemplate(query_gen_str)

def generate_queries(query: str, llm, num_queries: int = 4):
    print(f"*******start generate_queries and len(query):{getlen(query)}")
    start = time.time()
    response = llm.predict(
        query_gen_prompt, num_queries=num_queries, query=query
    )
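    end = time.time()
    print(f"query rewrite duration:{end - start}")
    query_rewrite_duration = end - start
    # assume the LLM puts each query on its own line; drop blank lines and keep
    # at most num_queries so every CSV row has the same number of columns
    queries = [q.strip() for q in response.split("\n") if q.strip()][:num_queries]
    queries_str = "\n".join(queries)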
    print(f"Generated queries:\n{queries_str}")
    print("*******end generate_queries\n\n")
    return queries, query_rewrite_duration

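def getlen(s):
    # whitespace-separated word count, used for all *_len columns
    return len(s.split())

print(f"\n\n\nstart embed_models:{embeddings}*************")

# warm up the LLM twice so one-time initialization is not included in the timings
tmp_query = origin_queries[0]
for k in range(2):
    Settings.llm.complete(tmp_query)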

def main(j):
    for em in embeddings:
        Settings.embed_model = em
        print(f"\n\n\n*********embed_model:{em}*************")
        start = time.time()
        documents = SimpleDirectoryReader("../data/llama2_paper/").load_data() #data/llmama2_paper.json ./data/survery/llm_survery_paper.json
        end = time.time()
        load_time = end - start
        start = time.time()
        index = VectorStoreIndex.from_documents(documents=documents)
        end = time.time()
        index_time = end - start  
        query_engine = index.as_query_engine()
        all_time = [] 
        all_time.append(em)
        query = origin_queries[j]
        tmp_query = query 
        query_name = "query_" + str(j)
        all_time.append(query_name)
        print(f"query*******:{query}")

        all_time.append(load_time)
        print(f"load_time:{load_time}")

        all_time.append(index_time)
        print(f"index from time:{index_time}")

        # query_engine = index.as_query_engine()
        assert(query_engine is not None)

        queries, query_rewrite_time = generate_queries(query, Settings.llm)
        all_time.append(query_rewrite_time)
        #warm up
        # for query1 in queries:   
        #     response = query_engine.query(query1)

        print("*******start query rewrite")

        for query1 in queries:
            print(f"type(query1):{type(query1)} and type(query_engine):{type(query_engine)}")
            start = time.time()
            # print(f"type(query1):{type(query1)} and type(query_engine):{type(query_engine)}")
            response = query_engine.query(query1)
            end = time.time()
            query_time = end - start
            all_time.append(query_time)
            all_time.append(getlen(query1))
            print(f"query+generate  duration:{end - start}")
            print("****finish one query****\n\n")
            display_response(response)

        print("*******end query rewrite")

        start = time.time()
        response = query_engine.query(tmp_query)
        end = time.time()
        query_wo_rewrite_time = end - start
        all_time.append(query_wo_rewrite_time)
        print(f"query_wo_rewrite_time:{query_wo_rewrite_time}")
        all_time.append(getlen(tmp_query))
        assert(query == tmp_query)
        print(f"*******start prompt on llm and len(tmp_query):{getlen(tmp_query)}")
        start = time.time()
        Settings.llm.complete(tmp_query)
        end = time.time()
        llm_prompt_duration = end - start
        print(f"query on llm directly duration:{end - start}")
        all_time.append(llm_prompt_duration)
        all_time.append(getlen(tmp_query))
        print(f"len(all_time):{len(all_time)} and len(column_name):{len(column_name)}")
        assert(len(all_time) == len(column_name))
        print(f"finish embed_model:{em}*************\n\n\n\n")

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Time query rewriting for one query')
    parser.add_argument('--index', type=int, required=True, help="index of the query in origin_queries")
    args = parser.parse_args()
    index = args.index
    main(index)
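The script builds `column_name`, collects one `all_time` row per `main(j)` call, and sets `filename` to a CSV path, but never writes the row. A minimal sketch of persisting it with the standard-library `csv` module follows; the helper name `append_row` is hypothetical, and it assumes the header should be written only when the file does not exist yet (or is empty).

```python
import csv
import os


def append_row(path, header, row):
    """Append one timing row to a CSV file, writing the header on first use.

    Hypothetical helper: it assumes `row` has one value per `header` entry,
    which is what the script's final assert already checks.
    """
    if len(row) != len(header):
        raise ValueError(f"expected {len(header)} values, got {len(row)}")
    # write the header only if the file is new or empty
    write_header = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, "a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(header)
        writer.writerow(row)
```

In `main`, it could be called right after the final assert as `append_row(filename, column_name, all_time)`, so that repeated runs with different `--index` values accumulate into one file.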
dosubot[bot] commented 8 months ago

Hey there, @lambda7xx! Fancy seeing you here again. 🙂 I'm diving into your issue and will get back to you with some insights soon. Hang tight!

lambda7xx commented 8 months ago

I added timers to /llama_index/llms/huggingface/base.py:

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Completion endpoint."""
        import time  # added for the timing prints below

        def getlen(s):
            # whitespace-separated word count, not a tokenizer token count
            return len(s.split())

        full_prompt = prompt
        print(f"1 complete, len(prompt): {getlen(prompt)}")
        # timer 2 covers prompt wrapping, tokenization and the copy to the model's device
        start = time.time()
        if not formatted:
            if self.query_wrapper_prompt:
                full_prompt = self.query_wrapper_prompt.format(query_str=prompt)
            if self.system_prompt:
                full_prompt = f"{self.system_prompt} {full_prompt}"

        inputs = self._tokenizer(full_prompt, return_tensors="pt")
        inputs = inputs.to(self._model.device)
        end = time.time()
        print("2 complete, the time of tokenizer: ", end-start)
        # remove keys from the tokenizer if needed, to avoid HF errors
        for key in self.tokenizer_outputs_to_remove:
            if key in inputs:
                inputs.pop(key, None)
        # timer 3 covers the whole autoregressive generation loop
        start = time.time()
        tokens = self._model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            stopping_criteria=self._stopping_criteria,
            **self.generate_kwargs,
        )  # (lambda): place to insert a torch.cuda.nvtx range
        end = time.time()
        print("3 complete, the time of model.generate: ", end-start)
        # keep only the newly generated tokens (drop the echoed prompt)
        completion_tokens = tokens[0][inputs["input_ids"].size(1) :]
        completion = self._tokenizer.decode(completion_tokens, skip_special_tokens=True)
        # word count of the decoded text; len(completion_tokens) would give the token count
        print("4 complete, len(completion): ", getlen(completion))
        print("************finish complete function************\n\n")

        return CompletionResponse(text=completion, raw={"model_output": tokens})
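The hand-written `start`/`end` pairs above could be factored into one helper. A minimal sketch, assuming wall-clock timings of GPU work are wanted: CUDA kernels are launched asynchronously with respect to the host, so a timer should call a synchronization function such as `torch.cuda.synchronize` before reading the clock. `model.generate` may already force synchronization during its loop, so this matters most when timing individual GPU operations. The `Timer` class and its `sync` parameter are hypothetical, not part of LlamaIndex or PyTorch.

```python
import time
from typing import Callable, Optional


class Timer:
    """Context manager that records elapsed wall-clock seconds in `.elapsed`.

    Hypothetical helper. Pass `sync=torch.cuda.synchronize` when timing GPU
    work so queued CUDA kernels finish before the clock is read.
    """

    def __init__(self, label: str = "", sync: Optional[Callable[[], None]] = None):
        self.label = label
        self.sync = sync
        self.elapsed = 0.0

    def __enter__(self):
        if self.sync is not None:
            self.sync()  # drain earlier work so it is not charged to this block
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        if self.sync is not None:
            self.sync()  # wait for this block's GPU work to complete
        self.elapsed = time.perf_counter() - self._start
        if self.label:
            print(f"{self.label}: {self.elapsed:.4f}s")
        return False  # never swallow exceptions
```

Inside `complete()`, the generation step would then read `with Timer("3 complete, the time of model.generate", sync=torch.cuda.synchronize): tokens = self._model.generate(...)`. `time.perf_counter` is used instead of `time.time` because it is monotonic and intended for measuring intervals.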
lambda7xx commented 8 months ago

I found that running query_engine.query(tmp_query) with the original query takes longer than running any of the rewritten queries.

Part of the log is below. My original query is: Based on the abstract of "Llama 2: Open Foundation and Fine-Tuned Chat Models," what are the two primary objectives achieved in this work, and what is the range of parameters for the large language models developed? For the original query, the prompt length is 36, the number of new tokens is 149, and the latency is 12.3 seconds; for the rewritten query, the prompt length is 923, the number of new tokens is 29, and the latency is 2.75 seconds. Why is the rewritten query's latency smaller than the original query's?

<IPython.core.display.Markdown object>
type(query1):<class 'str'> and type(query_engine):<class 'llama_index.core.query_engine.retriever_query_engine.RetrieverQueryEngine'>
*****retriever_query_engine.py tart _query*****
retriever_query_engine.py retriever time: 0.017117977142333984
1 Refine::get_response, get_response, the prev_response is:  None
2 Refine::get_response, get_response, the len(text_chunk) is: 877 and  len(query_str) is: 24
1 _give_response_single, the type(text_qa_template): <class 'llama_index.core.prompts.base.SelectorPromptTemplate'> and the type(text_chunks): <class 'list'>
1 Refine::_default_program_factory, self._structured_answer_filtering: False
2 _give_response_single, the responese:None and the self._streaming: False and type(cur_text_chunk): <class 'str'>
3 _give_response_single, the response is None and not self._streaming
DefaultRefineProgram::__call, the self._output_cls: None
DefaultRefineProgram::__call, call llm::predict

*******start llm predict
1 complete, len(prompt): 923
2 complete, the time of tokenizer:  0.0053899288177490234
3 complete, the time of model.generate:  2.725148916244507
4 complete, len(completion):  29
************finish complete function************

LLM::predict, the predict time 2.7312374114990234
Refine::get_response, self._give_response_single, the time is: 2.7343428134918213
retriever_query_engine.py  self._response_synthesizer.synthesize: 2.7391698360443115
query+generate  duration:2.7564148902893066
****finish one query****

*****retriever_query_engine.py tart _query*****
retriever_query_engine.py retriever time: 0.01692509651184082
1 Refine::get_response, get_response, the prev_response is:  None
2 Refine::get_response, get_response, the len(text_chunk) is: 877 and  len(query_str) is: 36
1 _give_response_single, the type(text_qa_template): <class 'llama_index.core.prompts.base.SelectorPromptTemplate'> and the type(text_chunks): <class 'list'>
1 Refine::_default_program_factory, self._structured_answer_filtering: False
2 _give_response_single, the responese:None and the self._streaming: False and type(cur_text_chunk): <class 'str'>
3 _give_response_single, the response is None and not self._streaming
DefaultRefineProgram::__call, the self._output_cls: None
DefaultRefineProgram::__call, call llm::predict

*******start llm predict
1 complete, len(prompt): 935
2 complete, the time of tokenizer:  0.00542449951171875
3 complete, the time of model.generate:  12.346087455749512
4 complete, len(completion):  172
************finish complete function************

LLM::predict, the predict time 12.352328300476074
Refine::get_response, self._give_response_single, the time is: 12.35537576675415
retriever_query_engine.py  self._response_synthesizer.synthesize: 12.360142707824707
query_wo_rewrite_time:12.377184391021729
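One way to read these numbers, using only the `model.generate` timings and completion lengths printed in the log above, is to normalize generation time by output length. Note that `getlen` counts whitespace-separated words, not tokenizer tokens, so the rates below are seconds per generated word, an approximation of per-token cost. In autoregressive decoding the prompt is processed in a single forward pass (prefill), while each new token needs its own sequential forward pass, so generation time tends to grow with output length far more than with prompt length.

```python
# model.generate timings (seconds) and lengths (whitespace-separated words, as
# counted by getlen) copied from the log above
runs = {
    "rewritten query": {"generate_s": 2.725148916244507, "prompt_words": 923, "completion_words": 29},
    "original query": {"generate_s": 12.346087455749512, "prompt_words": 935, "completion_words": 172},
}

# seconds spent in model.generate per generated word
rates = {name: r["generate_s"] / r["completion_words"] for name, r in runs.items()}

for name, r in runs.items():
    print(
        f"{name}: {r['prompt_words']} prompt words, {r['completion_words']} generated words, "
        f"{rates[name]:.3f} s per generated word"
    )
```

This prints about 0.094 s per word for the rewritten query and 0.072 s per word for the original one. Both prompts are nearly the same size (923 vs 935 words), while the original query produced about six times as many words (172 vs 29), which is consistent with output length, not prompt length, driving the 12.3 s vs 2.7 s difference. If shorter answers are acceptable, lowering `max_new_tokens` (256 in the script) bounds this cost.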
dosubot[bot] commented 8 months ago

Based on the information you've provided, it seems like the latency issue you're experiencing might be due to several factors inherent in the design and execution flow of the LlamaIndex query processing mechanism. Here are some potential reasons:

  1. Recursive Query Processing: The query processing mechanism is recursive, especially for non-leaf nodes. This means that for each level of the tree, the system potentially makes multiple calls to either _query_with_selected_node or _query_level methods, which in turn may involve calls to external services (e.g., self._service_context.llm.predict). The recursion depth and the number of recursive calls can significantly increase the latency, particularly if the original query requires deeper traversal in the tree structure compared to a new query.

  2. External Service Calls: The methods _query_with_selected_node and _query_level involve calls to external services for prediction (self._service_context.llm.predict). The latency of these external calls can vary based on several factors such as network latency, the load on the external service, and the complexity of the query. If the original query involves more complex or numerous external service calls than a new query, this could lead to higher latency.

  3. Dynamic Template Selection and Text Splitting: The system dynamically selects query templates (query_template or query_template_multiple) and performs text splitting based on the number of child nodes and the branching factor. This dynamic behavior, including the construction of numbered node text and the handling of multiple chunks of text, can introduce additional computational overhead, especially if the original query results in a more complex branching structure.

  4. Node Selection and Number Extraction: The process of selecting nodes based on the responses from the external service (extract_numbers_given_response) and the subsequent logic to handle invalid responses or out-of-range numbers can add to the latency. This is more pronounced if the original query leads to responses that require additional handling for number extraction and validation.

  5. Verbose Logging and Debugging: The verbose logging and debugging statements, although useful for development and troubleshooting, can introduce I/O overhead, especially if the logging level is set to debug and the system is logging detailed information for each step of the query processing.

In summary, the higher latency for the original query compared to a new query could be due to the recursive nature of the query processing, the reliance on external service calls, the dynamic selection of query templates and text splitting, the complexity of node selection and number extraction, and the potential overhead from verbose logging. Each of these factors can contribute to the overall latency, and their impact can vary based on the specifics of the original query and the structure of the LlamaIndex.

You can find more details about the query processing mechanism in the select_leaf_retriever.py file.

To resolve this issue, you might want to consider optimizing the factors mentioned above. For instance, you could try to reduce the recursion depth or the number of external service calls, simplify the dynamic template selection and text splitting, streamline the node selection and number extraction, or adjust the logging level to reduce I/O overhead. However, please note that these optimizations should be done carefully as they could potentially affect the functionality and performance of other parts of the LlamaIndex.

I hope this helps! If you have any further questions or need more clarification, feel free to ask.

lambda7xx commented 8 months ago

Does LlamaIndex support batch queries? Currently my code handles a single query. If I want to run a batch of queries, how do I do that?
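A minimal sketch of issuing several queries concurrently, assuming the query engine exposes an async `aquery(query)` method (LlamaIndex query engines do provide `aquery`). To keep the example self-contained and runnable, `EchoQueryEngine` is a hypothetical stand-in; in the script above you would pass `query_engine` instead. Concurrency mainly helps when the LLM is a remote API: with a local `HuggingFaceLLM`, generation may still run one request at a time, and true GPU batching would need batched `model.generate` calls on padded inputs, which this sketch does not do.

```python
import asyncio
from typing import Any, List, Sequence


class EchoQueryEngine:
    """Hypothetical stand-in for a LlamaIndex query engine with `aquery`."""

    async def aquery(self, query: str) -> str:
        await asyncio.sleep(0.01)  # simulate retrieval + LLM latency
        return f"answer to: {query}"


async def batch_query(engine: Any, queries: Sequence[str], max_concurrency: int = 4) -> List[Any]:
    """Run engine.aquery over all queries concurrently, preserving input order."""
    semaphore = asyncio.Semaphore(max_concurrency)  # cap in-flight requests

    async def run_one(q: str) -> Any:
        async with semaphore:
            return await engine.aquery(q)

    # gather returns results in the same order as the input queries
    return await asyncio.gather(*(run_one(q) for q in queries))


if __name__ == "__main__":
    answers = asyncio.run(batch_query(EchoQueryEngine(), ["q1", "q2", "q3"]))
    for a in answers:
        print(a)
```

With the script above, `asyncio.run(batch_query(query_engine, queries))` could replace the sequential `for query1 in queries:` loop, although the per-query timings would then overlap and no longer be comparable column by column.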