deepset-ai / haystack

AI orchestration framework to build customizable, production-ready LLM applications. Connect components (models, vector DBs, file converters) to pipelines or agents that can interact with your data. With advanced retrieval methods, it's best suited for building RAG, question answering, semantic search or conversational agent chatbots.
https://haystack.deepset.ai
Apache License 2.0

OpenAIGenerator and OpenAIChatGenerator streaming_callback as kwargs #7836

Closed: Redna closed this issue 3 months ago.

Redna commented 5 months ago

Is your feature request related to a problem? Please describe.
The current implementation of OpenAIGenerator/OpenAIChatGenerator does not allow passing a streaming_callback as a parameter to the pipeline.run function. This is a problem when I want to create a FastAPI endpoint that streams its response with Server-Sent Events.

Currently, I need to create a separate pipeline for every incoming request. Creating a pipeline per request can be slow because dependencies have to be loaded and models warmed up, and it causes further issues when tracers such as Langfuse are used.

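import asyncio

from fastapi.responses import StreamingResponse
from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.generators import OpenAIGenerator
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.utils import Secret

# app, ChatRequest, docstore, template, TextIteratorStreamer and consume_streamer are defined elsewhere.

@app.post("/chat")
async def handle_request(request: ChatRequest) -> StreamingResponse:
    loop = asyncio.get_running_loop()
    streamer = TextIteratorStreamer()  # custom implementation
    query = request.query  # assumes ChatRequest exposes the user's question as `query`

    # The whole pipeline is rebuilt per request only to bind this request's streaming_callback.
    pipe = Pipeline()
    pipe.add_component("retriever", InMemoryBM25Retriever(document_store=docstore))
    pipe.add_component("prompt_builder", PromptBuilder(template=template))
    pipe.add_component("llm", OpenAIGenerator(api_key=Secret.from_token("<your-api-key>"),
                                              api_base_url="http://localhost:30091/v1",
                                              streaming_callback=streamer.add))

    pipe.connect("retriever", "prompt_builder.documents")
    pipe.connect("prompt_builder", "llm")

    # Run the blocking pipeline in a worker thread; chunks reach the client through the streamer.
    loop.run_in_executor(None, pipe.run, {
        "prompt_builder": {"query": query},
        "retriever": {"query": query},
    })

    return StreamingResponse(
        consume_streamer(streamer),
        media_type="text/event-stream",
    )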
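`TextIteratorStreamer` and `consume_streamer` are custom helpers that the issue does not show. Below is a minimal sketch of what they might look like. It assumes the callback is invoked from the executor thread that runs `pipe.run`, while the response is consumed on the event loop. Both names and their behavior are hypothetical. `chunk` is only assumed to expose a `.content` string, as Haystack's `StreamingChunk` does.

```python
import asyncio
import json
from types import SimpleNamespace
from typing import AsyncIterator, List, Optional


class TextIteratorStreamer:
    """Hypothetical streamer bridging chunks from a worker thread to an asyncio consumer."""

    def __init__(self) -> None:
        # Must be created on the event-loop thread, e.g. inside the request handler.
        self._loop = asyncio.get_running_loop()
        self._queue: "asyncio.Queue[Optional[str]]" = asyncio.Queue()

    def add(self, chunk) -> None:
        # Called from the executor thread. asyncio.Queue is not thread-safe,
        # so the put is scheduled on the event loop instead of being done directly.
        self._loop.call_soon_threadsafe(self._queue.put_nowait, chunk.content)

    def end(self) -> None:
        # None is the sentinel telling the consumer that generation has finished.
        self._loop.call_soon_threadsafe(self._queue.put_nowait, None)

    async def texts(self) -> AsyncIterator[str]:
        while (text := await self._queue.get()) is not None:
            yield text


async def consume_streamer(streamer: TextIteratorStreamer) -> AsyncIterator[str]:
    # Wrap every text fragment in a Server-Sent Events "data" message.
    async for text in streamer.texts():
        payload = json.dumps({"content": text})
        yield f"data: {payload}\n\n"


async def demo() -> List[str]:
    streamer = TextIteratorStreamer()

    def fake_pipeline_run() -> None:
        # Stands in for pipe.run: the "LLM" emits chunks from a worker thread.
        try:
            for piece in ["Hel", "lo"]:
                streamer.add(SimpleNamespace(content=piece))
        finally:
            streamer.end()

    job = asyncio.get_running_loop().run_in_executor(None, fake_pipeline_run)
    events = [event async for event in consume_streamer(streamer)]
    await job
    return events
```

With this sketch, the executor job must call `streamer.end()` once `pipe.run` returns, for example in a `try`/`finally` wrapper like `fake_pipeline_run`. Otherwise the streamed response never finishes.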

Describe the solution you'd like
I would like to pass the streaming callback to the pipeline's run method, as is already possible with, for example, the Amazon Bedrock generator: https://github.com/deepset-ai/haystack-core-integrations/blob/main/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py#L202

so in essence:


pipe = Pipeline()
pipe.add_component("retriever", InMemoryBM25Retriever(document_store=docstore))
pipe.add_component("prompt_builder", PromptBuilder(template=template))
pipe.add_component("llm", OpenAIGenerator(api_key=Secret.from_token("<your-api-key>"),
                                          api_base_url="http://localhost:30091/v1"))
# Connect once at startup: connecting the same sockets again on every request would raise an error.
pipe.connect("retriever", "prompt_builder.documents")
pipe.connect("prompt_builder", "llm")

@app.post("/chat")
async def handle_request(request: ChatRequest) -> StreamingResponse:
    loop = asyncio.get_running_loop()
    streamer = TextIteratorStreamer()  # custom implementation
    query = request.query  # assumes ChatRequest exposes the user's question as `query`

    loop.run_in_executor(None, pipe.run, {
        "prompt_builder": {"query": query},
        "retriever": {"query": query},
        "llm": {
            # Proposed (not currently supported): a per-request callback passed at run time
            "generation_kwargs": {"streaming_callback": streamer.add},
        },
    })

    return StreamingResponse(
        consume_streamer(streamer),
        media_type="text/event-stream",
    )
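If the callback travels inside `generation_kwargs` as proposed, the generator would have to take it out again before forwarding the remaining kwargs to the OpenAI client, because a Python callable is not a valid request parameter. Below is a minimal sketch of that merge-and-extract step. It assumes run-time `generation_kwargs` override the init-time ones. `split_streaming_callback` is a hypothetical helper, not part of Haystack.

```python
from typing import Any, Callable, Dict, Optional, Tuple

StreamingCallback = Callable[[Any], None]


def split_streaming_callback(
    init_callback: Optional[StreamingCallback],
    init_kwargs: Optional[Dict[str, Any]],
    run_kwargs: Optional[Dict[str, Any]],
) -> Tuple[Optional[StreamingCallback], Dict[str, Any]]:
    """Hypothetical helper: merge generation kwargs and pull a run-time callback out of them."""
    # Run-time kwargs win over init-time ones; the inputs themselves are not mutated.
    merged = {**(init_kwargs or {}), **(run_kwargs or {})}
    # Remove the callable so that only real API parameters are forwarded to the client.
    run_callback = merged.pop("streaming_callback", None)
    # Fall back to the callback configured when the component was created.
    return run_callback or init_callback, merged
```

Mixing a local Python hook into what are otherwise API parameters is one argument for a dedicated `streaming_callback` parameter instead.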

Describe alternatives you've considered
Adding a dedicated streaming_callback parameter to run(). However, might that be a breaking change?

@component.output_types(replies=List[ChatMessage])
def run(
    self,
    messages: List[ChatMessage],
    generation_kwargs: Optional[Dict[str, Any]] = None,
    streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
    ...
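In plain Python terms, adding an optional keyword argument with a `None` default does not break existing callers, because calls that omit it behave exactly as before. The toy sketch below shows the precedence rule: the run-time callback is used first, and the init-time callback is the fallback. `ToyChatGenerator` is hypothetical, uses plain strings instead of `ChatMessage`, and does not depend on Haystack. In Haystack, `run()` parameters that have defaults become optional component inputs, so pipelines that never set the new input should keep working too.

```python
from typing import Any, Callable, Dict, List, Optional

Callback = Callable[[str], None]


class ToyChatGenerator:
    """Hypothetical stand-in for a chat generator: it 'streams' the last message back word by word."""

    def __init__(self, streaming_callback: Optional[Callback] = None) -> None:
        self.streaming_callback = streaming_callback

    def run(
        self,
        messages: List[str],
        generation_kwargs: Optional[Dict[str, Any]] = None,  # accepted but unused in this toy
        streaming_callback: Optional[Callback] = None,
    ) -> Dict[str, List[str]]:
        # New optional keyword: callers that omit it get the init-time callback, exactly as before.
        callback = streaming_callback or self.streaming_callback
        words = messages[-1].split()
        if callback is not None:
            for word in words:
                callback(word)
        return {"replies": [" ".join(words)]}
```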
LastRemote commented 2 months ago

Sorry for necromancing the thread, but I think there is probably a more graceful way to do this. The downside of making streaming_callback a runtime variable is that the streaming logic would no longer be included when the pipeline is serialized. The entire streamer instance would also need to be re-instantiated for every request, even though the only difference is probably the request context, or simply the request ID. (IMO, the request context should not be passed directly to the pipeline, as it really has nothing to do with the pipeline.)

I've come up with an alternative solution that I think might be more general (?). It does not need to re-create either the pipeline or the streamer:

import threading

import fastapi
from haystack import Pipeline
from starlette_context import context, request_cycle_context

# ChatGenerator, retriever, prompt_builder, io_handler, stream_generator and app are custom objects defined elsewhere.
generator = ChatGenerator(
    generation_kwargs={"temperature": 0.4},
    streaming_callback=stream_generator.on_llm_new_chunk,  # custom implementation that uses the request context to isolate streaming chunks between requests
)

# Initialize pipeline and add connections
pipeline = Pipeline()
pipeline.add_component("retriever", retriever)
pipeline.add_component("prompt_builder", prompt_builder)
pipeline.add_component("generator", generator)

pipeline.connect("retriever", "prompt_builder.documents")
pipeline.connect("prompt_builder", "generator")

def run_pipeline(query: str, context_dict):
    res = None  # stays None if the pipeline raises, instead of raising UnboundLocalError on return
    # Re-enter the request's context in this worker thread so the callback can identify the request
    with request_cycle_context(context_dict):
        try:
            res = pipeline.run(
                io_handler.parse_input({"query": query}),
                include_outputs_from=io_handler.get_output_components(),
            )
        except Exception as err:
            stream_generator.on_pipeline_error(err)
        finally:
            stream_generator.on_pipeline_end()
    return io_handler.parse_output(res) if res is not None else None

@app.post("/run")
async def run(query: str):
    # Pass context information to each run_pipeline call
    context_dict = context.copy()
    threading.Thread(target=run_pipeline, args=(query, context_dict)).start()
    return fastapi.responses.StreamingResponse(
        stream_generator.stream_data(), media_type="text/event-stream;charset=UTF-8"
    )
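The isolation trick relies on two facts. First, starlette_context keeps per-request data in Python context variables. Second, a thread started with `threading.Thread` does not see the caller's context variables by default, which is why the context is copied and re-entered with `request_cycle_context` inside the thread. Below is a stdlib-only sketch of the same routing idea, with `contextvars.Context.run` playing the role of `request_cycle_context`. `request_id_var`, `RequestRouter`, `fake_pipeline_run` and `start_request` are hypothetical names.

```python
import contextvars
import threading
from collections import defaultdict
from typing import Dict, List

# Hypothetical stand-in for the request ID that starlette_context stores per request.
request_id_var = contextvars.ContextVar("request_id")


class RequestRouter:
    """One shared callback; each chunk is filed under the request ID found in the current context."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.chunks: Dict[str, List[str]] = defaultdict(list)

    def on_chunk(self, chunk: str) -> None:
        request_id = request_id_var.get()  # LookupError if called outside a request context
        with self._lock:
            self.chunks[request_id].append(chunk)


def fake_pipeline_run(router: RequestRouter, text: str) -> None:
    # Stands in for pipeline.run: the generator calls the same callback for every chunk.
    for word in text.split():
        router.on_chunk(word)


def start_request(router: RequestRouter, request_id: str, text: str) -> threading.Thread:
    # Private snapshot of the current context, similar to the per-request context of an ASGI task.
    ctx = contextvars.copy_context()
    ctx.run(request_id_var.set, request_id)  # sets the ID only inside this snapshot
    # Run the pipeline inside the snapshot, since the new thread would not see it otherwise
    # (the role request_cycle_context(context_dict) plays in the snippet above).
    thread = threading.Thread(target=ctx.run, args=(fake_pipeline_run, router, text))
    thread.start()
    return thread
```

Two concurrent "requests" share one router and one callback, yet their chunks end up in separate buckets. The caller's own context never sees either request ID.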
ilkersigirci commented 2 months ago

Thanks for this. Could you share the whole working example including the imports, please?

LastRemote commented 2 months ago

I did a lot of customization for my use case, so it is hard for me to provide an example that works out of the box. But here is a more complete version of it:

import asyncio
import json
import logging
import queue
import threading
from dataclasses import dataclass
from typing import Dict, Optional

import fastapi
import uvicorn
from haystack import Pipeline
from haystack.dataclasses import StreamingChunk
from starlette.middleware import Middleware
from starlette_context import context, plugins, request_cycle_context
from starlette_context.middleware import RawContextMiddleware


@dataclass
class ServerSentEvent:
    """Minimal SSE message holder (a stand-in: the original snippet does not show where ServerSentEvent comes from)."""
    data: str
    event: Optional[str] = None


class StreamingChunkGenerator:
    """
    A callback that handles streaming content from model services and provides a generator to stream the data
    to an API endpoint. Each request is identified by its request ID, and the streaming data for each request
    is stored in its own queue.
    """

    def __init__(self):
        # request ID -> queue of ServerSentEvent objects; None marks the end of a stream.
        # queue.Queue (thread-safe) is used instead of asyncio.Queue because producers run in a worker thread.
        self._streaming_data: Dict[str, queue.Queue] = {}

    def _get_queue(self, request_id: str) -> queue.Queue:
        # dict.setdefault is atomic for built-in dicts, so the pipeline thread and the response
        # coroutine share one queue no matter which of them gets here first.
        return self._streaming_data.setdefault(request_id, queue.Queue())

    def on_pipeline_error(self, err: Exception):
        request_id = context.data.get("X-Request-ID")
        self._get_queue(request_id).put_nowait(
            ServerSentEvent(event="error", data=json.dumps({"error": {"message": str(err)}}, ensure_ascii=False))
        )

    def on_model_service_new_chunk(self, chunk: StreamingChunk) -> None:
        """
        Callback to handle a new chunk of streaming data from the model service. This is only available when streaming
        is enabled.

        :param chunk: The new chunk of streaming data.
        """
        request_id = context.data.get("X-Request-ID")
        self._get_queue(request_id).put_nowait(
            ServerSentEvent(data=json.dumps({"content": chunk.content, "meta": chunk.meta}, ensure_ascii=False))
        )

    def on_pipeline_end(self):
        """
        Callback to handle the end of the streaming data from the model service.
        """
        request_id = context.data.get("X-Request-ID")
        self._get_queue(request_id).put_nowait(None)

    async def stream_data(self):
        """
        Stream the data from the model service.
        """
        request_id = context.data.get("X-Request-ID")
        q = self._get_queue(request_id)
        try:
            while True:
                # Wait in a worker thread (Python 3.9+) instead of busy-waiting, so the event loop stays responsive.
                sse = await asyncio.to_thread(q.get)
                if sse is None:
                    break
                # TODO: Add a utility function to yield ServerSentEvent objects
                if sse.event:
                    yield f"event:{sse.event}\n"
                yield f"data:{sse.data}\n\n"
        finally:
            # Delete the entry after streaming to prevent a memory leak.
            self._streaming_data.pop(request_id, None)

app = fastapi.FastAPI(middleware=[Middleware(RawContextMiddleware, plugins=(plugins.RequestIdPlugin(),))])
stream_generator = StreamingChunkGenerator()

logging.basicConfig(format="%(asctime)s %(levelname)s:%(message)s", level=logging.DEBUG)

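def create_pipeline():
    ...
    # ChatGenerator is a customized generator component (not shown); `stream` and
    # `streaming_callbacks` are its own parameters, not those of Haystack's OpenAIChatGenerator.
    generator = ChatGenerator(
        model="gpt-4o",
        generation_kwargs={"temperature": 0.4},
        stream=True,
        streaming_callbacks=[stream_generator.on_model_service_new_chunk],
    )

    pipeline = Pipeline()
    pipeline.add_component("generator", generator)
    ...
    return pipeline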

pipeline = create_pipeline()

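def run_pipeline(query: str, context_dict):
    res = None  # stays None if the pipeline raises, instead of raising UnboundLocalError on return
    # Re-enter the request's context in this thread so the callbacks can read its X-Request-ID
    with request_cycle_context(context_dict):
        try:
            res = pipeline.run({"query": query})
        except Exception as err:
            stream_generator.on_pipeline_error(err)
        finally:
            stream_generator.on_pipeline_end()
    return res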

@app.post("/run")
async def run(query: str):
    # Pass context information to each run_pipeline call
    context_dict = context.copy()
    threading.Thread(target=run_pipeline, args=(query, context_dict)).start()
    return fastapi.responses.StreamingResponse(
        stream_generator.stream_data(), media_type="text/event-stream;charset=UTF-8"
    )

if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=8000)
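FastAPI treats the plain `str` parameter of `/run` as a query-string parameter, so a client would call something like `POST /run?query=...`. The sketch below shows what the client then receives: it groups the `event:`/`data:` lines emitted by `stream_data` into events. `parse_sse` is hypothetical and handles only the subset of the SSE format that `stream_data` produces (no `id:`, `retry:` or comment lines).

```python
import json
from typing import Iterable, Iterator, List, Optional, Tuple


def _field_value(line: str) -> str:
    # Text after the first colon; SSE allows a single optional space there, which is dropped.
    value = line.split(":", 1)[1]
    return value[1:] if value.startswith(" ") else value


def parse_sse(lines: Iterable[str]) -> Iterator[Tuple[Optional[str], dict]]:
    """Group SSE lines into (event name, JSON payload) pairs; a blank line ends an event."""
    event: Optional[str] = None
    data: List[str] = []
    for line in lines:
        if line == "":
            if data:
                yield event, json.loads("\n".join(data))
            event, data = None, []
        elif line.startswith("event:"):
            event = _field_value(line)
        elif line.startswith("data:"):
            data.append(_field_value(line))
```

With an HTTP client such as `requests`, the lines could come from `requests.post("http://localhost:8000/run", params={"query": "..."}, stream=True).iter_lines(decode_unicode=True)`. Chunk events arrive with no event name, and a pipeline failure arrives as an `error` event.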