Closed: Redna closed this issue 3 months ago
Sorry for necromancing the thread, but I think there is probably a more graceful way to do this. The downside of making streaming_callback a runtime variable is that the streaming logic is no longer included during serialization, and the entire streamer instance has to be re-instantiated on every request, even though the only thing that differs between requests is the request context, or simply the request ID. (IMO the request context info should not be passed directly to the pipeline, since it really has nothing to do with the pipeline.)
I've come up with an alternative solution which I think might be more general. It does not need to re-create the pipeline or the streamer:
import threading

import fastapi
from haystack import Pipeline
from starlette_context import context, plugins, request_cycle_context

# retriever, prompt_builder, io_handler, stream_generator, ChatGenerator and app
# are defined elsewhere in my application.
generator = ChatGenerator(
    generation_kwargs={"temperature": 0.4},
    # custom implementation that uses the request context to isolate streaming chunks between requests
    streaming_callback=stream_generator.on_llm_new_chunk,
)

# Initialize pipeline and add connections
pipeline = Pipeline()
pipeline.add_component("retriever", retriever)
pipeline.add_component("prompt_builder", prompt_builder)
pipeline.add_component("generator", generator)
pipeline.connect("retriever", "prompt_builder.documents")
pipeline.connect("prompt_builder", "generator")


def run_pipeline(query: str, context_dict):
    with request_cycle_context(context_dict):
        res = None  # avoid an unbound variable if pipeline.run raises
        try:
            res = pipeline.run(
                io_handler.parse_input({"query": query}),
                include_outputs_from=io_handler.get_output_components(),
            )
        except Exception as err:
            stream_generator.on_pipeline_error(err)
        finally:
            stream_generator.on_pipeline_end()
        return io_handler.parse_output(res)


@app.post("/run")
async def run(query: str):
    # Pass context information to each run_pipeline call
    context_dict = context.copy()
    threading.Thread(target=run_pipeline, args=(query, context_dict)).start()
    return fastapi.responses.StreamingResponse(
        stream_generator.stream_data(), media_type="text/event-stream;charset=UTF-8"
    )
Thanks for this. Could you share the whole working example including the imports, please?
I did a lot of customization in my use case, so it is a bit hard for me to provide an example that works out of the box. But here is a more complete version of it:
import asyncio
import json
import logging
import os
import threading

import fastapi
import uvicorn
from haystack import Pipeline
from haystack.dataclasses import StreamingChunk
# ServerSentEvent was not imported in the original snippet; it presumably comes from sse-starlette
from sse_starlette.sse import ServerSentEvent
from starlette.middleware import Middleware
from starlette_context import context, plugins, request_cycle_context
from starlette_context.middleware import RawContextMiddleware
class StreamingChunkGenerator:
    """
    A callback handler that receives streaming content from model services and provides a generator to
    stream the data to an API endpoint. Each request is identified by its request ID, and the streaming
    data is stored in a separate queue per request.
    """

    def __init__(self):
        self._streaming_data = {}

    def on_pipeline_error(self, err: Exception):
        request_id = context.data.get("X-Request-ID")
        if request_id not in self._streaming_data:
            self._streaming_data[request_id] = asyncio.Queue()
        self._streaming_data[request_id].put_nowait(
            ServerSentEvent(event="error", data=json.dumps({"error": {"message": str(err)}}, ensure_ascii=False))
        )

    def on_model_service_new_chunk(self, chunk: StreamingChunk) -> None:
        """
        Callback to handle a new chunk of streaming data from the model service. This is only called when
        streaming is enabled.

        :param chunk: The new chunk of streaming data.
        """
        request_id = context.data.get("X-Request-ID")
        if request_id not in self._streaming_data:
            self._streaming_data[request_id] = asyncio.Queue()
        self._streaming_data[request_id].put_nowait(
            ServerSentEvent(data=json.dumps({"content": chunk.content, "meta": chunk.meta}, ensure_ascii=False))
        )

    def on_pipeline_end(self):
        """
        Callback to handle the end of the streaming data from the model service.
        """
        request_id = context.data.get("X-Request-ID")
        if request_id not in self._streaming_data:
            self._streaming_data[request_id] = asyncio.Queue()
        self._streaming_data[request_id].put_nowait(None)

    async def stream_data(self):
        """
        Stream the data for the current request to the API client.
        """
        request_id = context.data.get("X-Request-ID")
        while True:
            if request_id not in self._streaming_data:  # wait for the first chunk to arrive
                await asyncio.sleep(0.05)  # yield control so this wait does not block the event loop
                continue
            sse = await self._streaming_data[request_id].get()
            if sse is None:
                break
            # TODO: Add a utility function to yield ServerSentEvent objects
            if sse.event:
                yield f"event:{sse.event}\n"
            yield f"data:{sse.data}\n\n"
        del self._streaming_data[request_id]  # delete the entry after streaming to prevent a memory leak
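
# (Not part of the original post.) One possible shape for the TODO above: a small helper that renders
# a ServerSentEvent into SSE wire format, so stream_data() could simply `yield format_sse(sse)`.
# Hypothetical sketch only.
def format_sse(sse: ServerSentEvent) -> str:
    lines = []
    if sse.event:
        lines.append(f"event:{sse.event}")
    lines.append(f"data:{sse.data}")
    return "\n".join(lines) + "\n\n"
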
app = fastapi.FastAPI(middleware=[Middleware(RawContextMiddleware, plugins=(plugins.RequestIdPlugin(),))])
stream_generator = StreamingChunkGenerator()
logging.basicConfig(format="%(asctime)s %(levelname)s:%(message)s", level=logging.DEBUG)
def create_pipeline():
    ...
    # ChatGenerator is a customized generator class from my project (not shown here)
    generator = ChatGenerator(
        model="gpt-4o",
        generation_kwargs={"temperature": 0.4},
        stream=True,
        streaming_callbacks=[stream_generator.on_model_service_new_chunk],
    )
    pipeline = Pipeline()
    pipeline.add_component("generator", generator)
    ...
    return pipeline
pipeline = create_pipeline()
def run_pipeline(query: str, context_dict):
    with request_cycle_context(context_dict):
        res = None  # avoid an unbound variable if pipeline.run raises
        try:
            res = pipeline.run(
                {"query": query},
            )
        except Exception as err:
            stream_generator.on_pipeline_error(err)
        finally:
            stream_generator.on_pipeline_end()
        return res
@app.post("/run")
async def run(query: str):
# Pass context information to each run_pipeline call
context_dict = context.copy()
threading.Thread(target=run_pipeline, args=(query, context_dict)).start()
return fastapi.responses.StreamingResponse(
stream_generator.stream_data(), media_type="text/event-stream;charset=UTF-8"
)
if __name__ == "__main__":
uvicorn.run(app, host="localhost", port=8000)
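For completeness (this is not from the original example), here is one way the endpoint could be exercised once the server is running; the URL and query value below are placeholders:

import httpx

# Open a streaming POST request against the /run endpoint and print the server-sent events as they arrive.
with httpx.stream("POST", "http://localhost:8000/run", params={"query": "What is Haystack?"}) as response:
    for line in response.iter_lines():
        if line:
            print(line)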
Is your feature request related to a problem? Please describe. The current implementation of the OpenAIGenerator/OpenAIChatGenerator does not allow passing a streaming_callback as a parameter in the pipeline.run function. This causes issues when I want to create a FastAPI endpoint with server-sent events.
Currently I need to create a separate pipeline for each incoming request. Creating a pipeline per request can be slow because of loading dependencies and warming up models, and it causes other issues when using tracers like Langfuse.
Describe the solution you'd like I would like to pass the streaming callback in the pipeline run method, like it is done e.g. for the Bedrock generator: https://github.com/deepset-ai/haystack-core-integrations/blob/main/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py#L202
so in essence:
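(Purely as an illustration of the requested API, with made-up component names and a made-up callback; this is not code from the thread:)

def my_callback(chunk):
    # hypothetical per-request handler for streaming chunks
    print(chunk.content, end="", flush=True)

result = pipeline.run(
    {
        "prompt_builder": {"query": query},
        # desired: supply the streaming callback at run time instead of at construction time
        "llm": {"streaming_callback": my_callback},
    }
)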
Describe alternatives you've considered Adding another dedicated parameter like streaming_callback; however, that might be a breaking change.