Chainlit / chainlit

Build Conversational AI in minutes ⚡️
https://docs.chainlit.io
Apache License 2.0
6.74k stars 874 forks source link

Human Feedback buttons disappear when message is sent/updated within cl.Step context manager #1202

Open GillesJ opened 1 month ago

GillesJ commented 1 month ago

Describe the bug I am running a RAG assistant app with chainlit + llama_index and using the default LiteralAI setup for Human Feedback. The Feedback UI thumbs appeared correctly until I added cl.Step context managers, but since cleaning up message handling and adding Step context managers the feedback buttons have disappeared from the UI.

Any ideas why creating and updating a message within a Step context call (async with cl.Step(name="RAG", type="tool"):) would remove the Feedback thumbs UI? The feedback UI correctly appears on the last message when not using Step context managers.

To Reproduce Steps to reproduce the behavior:

import os
import chainlit as cl
from dotenv import load_dotenv
from rag import get_chat_store, get_cot_engine, patch_chat_completion
from response_postprocess import create_response_with_cited_sources_llm
from loguru import logger
from llama_index.core import Settings
from llama_index.core.callbacks import CallbackManager
from sources.rag_app.logging import data_logger
import time
from typing import Any
from prompts import rephrase_query_response
from literalai import LiteralClient
# llama_index.core.set_global_handler("simple") # Uncomment this if you want to debug full llm prompts.

load_dotenv()
patch_chat_completion()

lai = LiteralClient(api_key=os.environ["LITERAL_API_KEY"])
logger.debug(f"LiteralAI client loaded.")

NAME_SYSTEM = "Gizmo" # Assistant name.

tracer = None
if os.getenv("WEBSITE_SITE_NAME"):
    from azure.monitor.opentelemetry import configure_azure_monitor
    from opentelemetry import trace

    configure_azure_monitor()
    tracer = trace.get_tracer(__name__)

@cl.on_chat_start
async def factory():
    logger.info("Chat session started. Setting up environment.")

    # Set session counters for monitoring.
    cl.user_session.set("user_message_count", 0)
    cl.user_session.set("system_message_count", 0)
    cl.user_session.set("citation_message_count", 0)
    cl.user_session.set("no_source_count", 0)

    # Set the callback
    Settings.callback_manager = CallbackManager([cl.LlamaIndexCallbackHandler()])

    chat_store_memory = get_chat_store()
    rag_engine = get_cot_engine(memory=chat_store_memory)
    cl.user_session.set("rag_engine", rag_engine)

    await cl.Message(author=NAME_SYSTEM, content="Hallo! Wat is je vraag?").send()

async def get_session_meta(meta_keys: list | None = None) -> dict[str, Any]:
    if meta_keys is None:  # defaults
        meta_keys = [
            "company_id",
            "id",
            "user_message_count",
            "system_message_count",
            "citation_message_count",
            "no_source_count",
        ]
    data = {k: cl.user_session.get(k) for k in meta_keys}
    return {k: v for k, v in data.items() if v is not None}  # Clean-up none values

async def increment_counter(counter_name: str) -> None:
    counter = cl.user_session.get(counter_name, 0)
    counter += 1
    cl.user_session.set(counter_name, counter)

@cl.on_message  # this function will be called every time a user inputs a message in the UI
async def main(message: cl.Message):
    async def handler():
        start_time = time.perf_counter()
        await increment_counter("user_message_count")
        data_logger(
            "User message received", message=message.content, **await get_session_meta()
        )
        logger.info(f"Received message: {message.content}")
        rag_engine = cl.user_session.get("rag_engine")
        if not rag_engine:
            logger.warning(
                "RAG engine not found in session. Message processing aborted."
            )
            return
        async with cl.Step(name=f"`{NAME_SYSTEM} zoek en antwoord`", type="tool") as step:
            response = await cl.make_async(rag_engine.stream_chat)(message.content)
            logger.debug("RAG response processed.")

            # If there are sources nodes in the response, send the response with citations and sources.
            response_message = cl.Message(author=NAME_SYSTEM, content="")
            if response.source_nodes:
                for token in response.response_gen:
                    await response_message.stream_token(token=token)
                await response_message.send()

                initial_response_time = time.perf_counter() - start_time
                await increment_counter("system_message_count")
                data_logger(
                    "Initial system response sent to user.",
                    message=response_message.content,
                    response_time=initial_response_time,
                    **await get_session_meta(),
                )
                logger.debug("Initial response message sent.")

                # Add citations and sources.
                async with cl.Step(
                    name="`Bronnen controleren en citeren`", type="tool"
                ) as cite_step:
                    (
                        response_with_cite,
                        sources_text,
                        source_elements,
                    ) = await create_response_with_cited_sources_llm(response)
                    response_with_cite_fmt = (
                        response_with_cite + "\n\n---\n" + sources_text
                    )
                    response_message.content = response_with_cite_fmt
                    response_message.elements = source_elements

                    await response_message.send()
                    cite_step.output = response_with_cite_fmt
                    await cite_step.update()
                step.output = response_with_cite_fmt
                await step.update()

                cited_response_time = time.perf_counter() - start_time
                await increment_counter("citation_message_count")
                data_logger(
                    "Cited system response sent to user.",
                    message=response_message.content,
                    response_time=cited_response_time,
                    **await get_session_meta(),
                )

            else:
                response_message.content = rephrase_query_response
                await response_message.send()
                step.output = response_message.content
                await step.update()

                no_source_response_time = time.perf_counter() - start_time
                await increment_counter("no_source_count")
                data_logger(
                    "No sources found.",
                    message=response_message.content,
                    response_time=no_source_response_time,
                    **await get_session_meta(),
                )

            logger.debug("Final message with citations and sources sent.")

    if tracer:
        with tracer.start_as_current_span("main"):
            await handler()
    else:
        await handler()

Expected behavior Human Feedback UI appears below latest system message.

Screenshots image

Desktop (please complete the following information):

python = ">=3.11,<3.13" chainlit = "1.1.400" llama-index = "^0.10.58"

oshoma commented 3 weeks ago

Confirming, in my app, which also sends messages within cl.Step blocks: