langchain-ai / langchain

🦜🔗 Build context-aware reasoning applications
https://python.langchain.com
MIT License

Issue: Display metadata after streaming a response with FastAPI #5409

Closed: zigax1 closed this issue 11 months ago

zigax1 commented 1 year ago

Issue you'd like to raise.

I have successfully set up streaming over an HTTP call with FastAPI and OpenAI + ConversationalRetrievalChain.

If I don't use streaming and just return the whole response, as I was doing previously, the metadata is displayed along with the answer. If I enable streaming, only the answer is displayed, with a '%' at the end of the response.

Like: ......dummytext.%

The code responsible for streaming:

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
import queue

class ThreadedGenerator:
    def __init__(self):
        self.queue = queue.Queue()

    def __iter__(self):
        return self

    def __next__(self):
        # Block until the worker thread puts a token (or the StopIteration sentinel)
        item = self.queue.get()
        if item is StopIteration:
            raise item
        return item

    def send(self, data):
        self.queue.put(data)

    def close(self):
        # Sentinel to tell __next__ to stop iterating
        self.queue.put(StopIteration)

class ChainStreamHandler(StreamingStdOutCallbackHandler):
    def __init__(self, gen):
        super().__init__()
        self.gen = gen

    def on_llm_new_token(self, token: str, **kwargs):
        # Forward each new token from the LLM to the generator's queue
        self.gen.send(token)

Ask question function:

def askQuestion(self, generator, collection_id, question):
        try:
            collection_name = "collection-" + str(collection_id)
            self.llm = ChatOpenAI(
                model_name=self.model_name,
                temperature=self.temperature,
                openai_api_key=settings.OPENAI_API_KEY,
                streaming=True,
                verbose=VERBOSE,
                callback_manager=CallbackManager([ChainStreamHandler(generator)]),
            )
            self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True,  output_key='answer')

            self.chain = ConversationalRetrievalChain.from_llm(self.llm, chroma_Vectorstore.as_retriever(similarity_search_with_score=True),
                                                                return_source_documents=True,verbose=VERBOSE, 
                                                                memory=self.memory)

            result = self.chain({"question": question})

            res_dict = {
                "answer": result["answer"],
            }

            res_dict["source_documents"] = []

            for source in result["source_documents"]:
                res_dict["source_documents"].append({
                    "page_content": source.page_content,
                    "metadata":  source.metadata
                })

            return res_dict

        finally:
            generator.close()

And the API route itself

import threading

from fastapi import Request
from fastapi.responses import StreamingResponse

def stream(question, collection_id):
    generator = ThreadedGenerator()
    threading.Thread(target=thread_handler.askQuestion, args=(generator, collection_id, question)).start()
    return generator

@router.post("/collection/{collection_id}/ask_question")
async def ask_question(collection_id: str, request: Request):
    form_data = await request.form()
    question = form_data["question"]
    return StreamingResponse(stream(question, collection_id), media_type='text/event-stream')

In the askQuestion function I create the res_dict object, which stores the answer and also the source documents with their metadata.

How can I also display the source after the answer has finished streaming? (I already have the source in the metadata.)

Is it better to make a separate API call for this, or are there other practices?

Thanks everyone for the advice!

oneryalcin commented 1 year ago

Hi @zigax1, in my implementation I only used LangChain for accessing OpenSearch; I implemented streaming with the OpenAI API directly, skipping LangChain. Essentially, a helper function (streaming_request_chat) returns an async generator that yields whatever I want to stream, so all yield statements are custom to my application. In the example below, I stream (yield) all data from OpenAI and, once it is complete, I yield the metadata. It's not LangChain, but it may give some inspiration. (The code below is simplified, extracted from the running version.)

async def streaming_request_chat(
        prompt: str,
        version: str,
        metadata: List[Document],
        refactored_question: Optional[str] = None,
        model: str = 'gpt-3.5-turbo',
        temperature: float = 0,
        max_tokens: int = 400,
) -> Generator:
    """Generator for each chunk received from OpenAI as response

    :param refactored_question: refactored question
    :rtype: object
    :param version: Data/Algorithm version
    :param max_tokens: Number of output tokens
    :param temperature: Model temperature
    :param prompt: User Prompt.
    :param model: OpenAI Model name
    :param metadata: Last metadata to append
    :return: generator object for streaming response from OpenAI
    """
    try:
        response = await openai.ChatCompletion.acreate(
            model=model,
            # engine='gpt-35-turbo',
            messages=[
                {'role': 'user', 'content': prompt}
            ],
            max_tokens=max_tokens,
            temperature=temperature,
            stream=True
        )
    except Exception as e:
        logger.exception(e)
        yield f"data: Error occurred while calling OpenAI API: {str(e)}\n\n"
        return

    try:
        source_documents = [pydantic_to_dict(SourceDocuments(
            page_content=doc.page_content,
            metadata=SourceDocumentMetadata(**doc.metadata))
        ) for doc in metadata]
    except PydanticValidationError as e:
        # logger.exception(e)
        yield f"data: Error occurred while verifying Source Data, Details: {str(e)}\n\n"
        return

    _metadata = {
        "version": version,
        "source_documents": source_documents
    }

    async for chunk in response:
        extract = chunk['choices'][0].get('delta')
        content = extract.get('content')
        if content:
            yield "event: response_chunk\n"
            # we use json.dumps to make sure chunk content doesn't break SSE formatting
            # (newlines & similar special characters)
            yield "data: {} \n\n".format(json.dumps(content))

    yield "event: metadata\n"
    yield f"data: {json.dumps(_metadata, default=str)} \n\n"
    yield '[DONE] \n'

def response_iterator_factory(docs: List[Document], query: UserQuery) -> Generator:
    """Generator function to yield chunks of response from OpenAI

    :param docs: List of documents
    :param query: User query
    :return: Generator
    """

    # generate prompt based on documents and refactored_question
    prompt = gen_prompt(docs, query=query)

    # if model is chatgpt3 or gpt4, use streaming_request_chat
    return streaming_request_chat(
                prompt,
                version=SEARCH_AND_DATA_VERSION,
                metadata=docs,
                model=query.model.value,
                refactored_question=query.refactored_question,
                max_tokens=query.max_tokens,
                temperature=query.temperature,
            )

@app.on_event("startup")
async def startup_event():
    global docsearch

    logger.info("Loading vector store")
    # Vector store
    docsearch = OpenSearchVectorSearch(
        opensearch_url=OPENSEARCH_URL,
        index_name=OPENSEARCH_INDEX,
        embedding_function=OpenAIEmbeddings()
    )

@app.post('/streaming/ask')
async def streaming_ask(query: UserQuery) -> StreamingResponse:
    """Streaming API, this endpoint uses Server Side Events

    :param query: User question
    :return: Streaming Response chunks from OpenAI
    """

    question = query.question
    docs, question_embedding = docsearch.similarity_search(question, k=query.top_k, size=query.top_k)

    logger.info(f'Total Documents fetched: {len(docs)}')

    # call response_iterator_factory to get response iterator
    response_iterator = response_iterator_factory(docs=docs, query=query)
    return StreamingResponse(response_iterator, media_type="text/event-stream")
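
For reference, here is a minimal client-side sketch of consuming such an SSE stream, separating "response_chunk" events from the final "metadata" event. This is only illustrative; the endpoint URL, the payload shape and the httpx dependency are my assumptions, not part of the running service.

import json
from typing import Optional, Tuple

import httpx

def consume_sse(url: str, payload: dict) -> Tuple[str, Optional[dict]]:
    """Collect streamed answer chunks and the trailing metadata event."""
    answer_parts = []
    metadata = None
    current_event = "message"
    with httpx.stream("POST", url, json=payload, timeout=None) as response:
        for line in response.iter_lines():
            if line.startswith("event:"):
                # SSE event name, e.g. "response_chunk" or "metadata"
                current_event = line.split(":", 1)[1].strip()
            elif line.startswith("data:"):
                data = line.split(":", 1)[1].strip()
                if current_event == "response_chunk":
                    # chunks were json.dumps-ed on the server, so decode them back
                    answer_parts.append(json.loads(data))
                elif current_event == "metadata":
                    metadata = json.loads(data)
    return "".join(answer_parts), metadata

# Example (hypothetical endpoint):
# answer, metadata = consume_sse("http://localhost:8000/streaming/ask", {"question": "..."})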

zigax1 commented 1 year ago

Thank you @oneryalcin for such a wonderful answer. That really helped me a lot and I successfully implemented the behaviour in my code.

Exact changes which I did:

Callback handler and ThreadedGenerator --> I added a new function which passes the variable I want back from the worker thread:

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
import queue

class ThreadedGenerator:
    def __init__(self):
        self.queue = queue.Queue()
        self.res_dict = None

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()
        if item is StopIteration:
            raise item
        return item

    def send(self, data):
        self.queue.put(data)

    def set_res_dict(self, res_dict):
        # Store the final answer + metadata so the consumer can read it after streaming ends
        self.res_dict = res_dict

    def close(self):
        # Sentinel to tell __next__ to stop iterating
        self.queue.put(StopIteration)

class ChainStreamHandler(StreamingStdOutCallbackHandler):
    def __init__(self, gen):
        super().__init__()
        self.gen = gen

    def on_llm_new_token(self, token: str, **kwargs):
        # Forward each new token from the LLM to the generator's queue
        self.gen.send(token)

In my askQuestion function, I just pass the res_dict to the new function before returning it, like:

            result = self.chain({"question": question})
            res_dict = {
                "answer": result["answer"],
            }
            res_dict["source_documents"] = []

            for source in result["source_documents"]:
                res_dict["source_documents"].append({
                    "page_content": source.page_content,
                    "metadata":  source.metadata
                })

            generator.set_res_dict(res_dict)

            return res_dict

And the API route itself

def stream(question, collection_id, text_to_append):  # note: text_to_append is not used below
    generator = ThreadedGenerator()
    threading.Thread(target=thread_handler.askQuestion, args=(generator, collection_id, question)).start()

    for content in generator:
        if content:
            yield content

    yield "\n | \n"
    res_dict = generator.res_dict
    if res_dict:
        yield json.dumps(res_dict)
    yield '[DONE] \n'

@router.post("/collection/{collection_id}/ask_question")
async def ask_question(collection_id: str, request: Request):
    form_data = await request.form()
    question = form_data["question"]
    def event_stream():
        for token in stream(question, collection_id, 'dummy text'):
            yield token
    return StreamingResponse(event_stream(), media_type='text/event-stream')
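
On the consumer side, here is a minimal sketch (my assumption, not part of the changes above) of splitting the streamed body on the "\n | \n" separator to recover the answer text and the trailing res_dict JSON. A real client would render chunks as they arrive; this version simply accumulates them for clarity.

import json

import requests

def ask(url: str, question: str):
    """Accumulate the streamed body, then split the answer from the trailing JSON."""
    buffer = ""
    with requests.post(url, data={"question": question}, stream=True) as resp:
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            if chunk:
                buffer += chunk
    # Everything before the separator is the streamed answer; after it comes the res_dict JSON
    answer, _, tail = buffer.partition("\n | \n")
    tail = tail.replace("[DONE]", "").strip()
    metadata = json.loads(tail) if tail else None
    return answer, metadata

# Example (hypothetical URL):
# answer, metadata = ask("http://localhost:8000/collection/1/ask_question", "What is LangChain?")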

Great!

dosubot[bot] commented 11 months ago

Hi, @zigax1! I'm Dosu, and I'm helping the LangChain team manage their backlog. I wanted to let you know that we are marking this issue as stale.

From what I understand, the issue was about not being able to display metadata after streaming a response with FastAPI. @oneryalcin provided a helpful solution using a custom function and a threaded generator to send the desired variable to the thread. You successfully implemented the solution and shared the exact changes made to your code.

Before we close this issue, we wanted to check if it is still relevant to the latest version of the LangChain repository. If it is, please let us know by commenting on the issue. Otherwise, feel free to close the issue yourself or it will be automatically closed in 7 days.

Thank you for your contribution!