talkdai / dialog

RAG LLM Ops App for easy deployment and testing
https://dialog.talkd.ai
MIT License

Output containing messages that guided to prompt fallbacks #195

Open llemonS opened 4 months ago

llemonS commented 4 months ago

In order to increase the knowledge base for a specific subject over time, it would be interesting if we could collect the messages that led to the prompt fallbacks for later analysis. Maybe initially creating an output folder with a CSV file containing those scenarios.
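
Something along these lines is what I have in mind, just a minimal sketch to illustrate (the log_fallback helper, the output path and the CSV columns are all made up here and would need to be defined in the project):

import csv
from datetime import datetime, timezone
from pathlib import Path

# Hypothetical output location, following the existing dialog/data layout.
FALLBACK_LOG = Path("dialog/data/output/fallbacks.csv")

def log_fallback(question: str) -> None:
    """Append a message that triggered the prompt fallback, for later analysis."""
    FALLBACK_LOG.parent.mkdir(parents=True, exist_ok=True)
    is_new = not FALLBACK_LOG.exists()
    with FALLBACK_LOG.open("a", newline="") as f:
        writer = csv.writer(f)
        if is_new:
            writer.writerow(["timestamp", "question"])
        writer.writerow([datetime.now(timezone.utc).isoformat(), question])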

lgabs commented 4 months ago

Nice idea! LangSmith is a nice platform to debug each interaction with LLMs, but it's limited to langchain's chain calls, and indeed dialog's fallback implementation occurs before these calls, when no relevant documents are found by the retriever. Currently, this idea of saving fallbacks would probably be implemented optionally in the prompt generation step (e.g. the generate_prompt method of the AbstractRAG class). The fallback cases could be saved in a CSV initially, but maybe later they could be saved directly in the Postgres database.
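
Roughly what I mean, as a sketch only (the real generate_prompt signature on AbstractRAG is different, and record_fallback here is a stand-in for whatever helper we end up with, e.g. something like the CSV logger sketched above):

from typing import List

FALLBACK_PROMPT = "placeholder fallback text"

def record_fallback(question: str) -> None:
    # Stand-in for the actual persistence step (CSV now, maybe Postgres later).
    print(f"fallback triggered for: {question}")

def generate_prompt(question: str, relevant_docs: List[str], save_fallbacks: bool = True) -> str:
    """Sketch of the optional hook: when the retriever returns nothing relevant,
    record the question before falling back."""
    if not relevant_docs:
        if save_fallbacks:
            record_fallback(question)
        return FALLBACK_PROMPT
    # Stand-in for the real prompt construction with the retrieved context.
    return "\n\n".join(relevant_docs) + "\n\n" + question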

vmesel commented 4 months ago

What would the implementation look like from your POV, @llemonS?

llemonS commented 4 months ago

Well, considering the concept of the knowledge base (a .csv file) being used as an input, and taking a look at the project structure, maybe we could create a path dialog/data/output to store the messages that led to prompt fallbacks, in order to enable a sort of feedback cycle where the user can adapt the input .csv file later on.
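
The feedback cycle itself could then be as simple as reviewing that output file and appending curated rows to the knowledge base, something like this rough sketch (the file names and the question/answer columns are placeholders; the real knowledge base schema would be used instead):

import pandas as pd

# Placeholders for illustration only; the real file paths and column names
# come from dialog's knowledge base CSV.
fallbacks = pd.read_csv("dialog/data/output/fallbacks.csv")
knowledge_base = pd.read_csv("dialog/data/knowledge_base.csv")

# After reviewing the fallback questions and writing answers for them,
# append the curated rows and save the updated knowledge base.
curated = fallbacks[["question"]].assign(answer="<answer written during review>")
updated = pd.concat([knowledge_base, curated], ignore_index=True)
updated.to_csv("dialog/data/knowledge_base.csv", index=False)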

lgabs commented 3 months ago

Recently, I've managed to use the fallback in a more LCEL way, making it a runnable component as well, so its trace goes to LangSmith and we gain the information @llemonS talked about "for free", along with all the other LangSmith benefits. It may help not only with this issue but also with making the chain fully LCEL-adherent.

In the first screenshot, the RAG chain example received a question outside its designed context (see the second screenshot), and a chain_router (a Python function with routing rules, in this case checking the number of documents returned from the retriever) routed the chain to a FallbackChain, which returns the fixed AI fallback message you see in the first screenshot.

[Screenshot 1: LangSmith trace showing the run routed to the FallbackChain and its fixed fallback message]

[Screenshot 2: the question outside the chain's designed context]

This can be achieved with a more complex LCEL chain, which I've adapted from the dialog plugin I work with into this:

from typing import Dict, Any
from operator import itemgetter

from langchain_core.prompts import (
    ChatPromptTemplate,
)
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable, RunnableLambda, RunnableParallel
from langchain_core.messages import AIMessage

# I've invented these imports based on my plugin; the real implementation has to adapt here
from dialog_lib.settings import settings
from dialog_lib.vectorstore import get_retriever, combine_documents
from dialog_lib.models import Input

# Prompt
PROMPT_TEMPLATES = settings.PROJECT_CONFIG.get("prompt")
HEADER_PROMPT_TEMPLATE: str = PROMPT_TEMPLATES.get("header")
CONTEXT_PROMPT_TEMPLATE: str = PROMPT_TEMPLATES.get("context")
PROMPT_FALLBACK_MESSAGE: str = PROMPT_TEMPLATES.get(
    "fallback_not_found_relevant_contents"
)
fallback_message = ChatPromptTemplate.from_messages(
    [
        ("ai", PROMPT_FALLBACK_MESSAGE),
    ]
)
answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", HEADER_PROMPT_TEMPLATE),
        ("system", CONTEXT_PROMPT_TEMPLATE),
        ("human", "{question}"),
    ]
)

# LLM
MODEL_PARAMS = settings.PROJECT_CONFIG.get("model", {})
llm = ChatOpenAI(**MODEL_PARAMS, openai_api_key=settings.OPENAI_API_KEY)

# Build Chains

## Fallback Chain
def parse_fallback(ai_message: AIMessage) -> str:
    return ai_message.content

fallback_chain = (
    fallback_message | RunnableLambda(lambda x: x.messages[-1]) | parse_fallback
).with_config({"run_name": "FallbackChain"})

## Answer Chain
answer_chain = (
    (
        RunnableParallel(
            {
                "context": itemgetter("relevant_docs")
                | RunnableLambda(combine_documents),
                "question": itemgetter("question"),
            }
        ).with_config({"run_name": "QuestionWithContext"})
        | answer_prompt.with_config({"run_name": "AnswerPrompt"})
        | llm
        | StrOutputParser()
    )
).with_config({"run_name": "AnswerChain"})

def chain_router(inputs: Dict[str, Any]) -> Runnable:
    """
    Route conversation. If no relevant docs are found, we answer with a fixed fallback template, otherwise go ahead with our LLM Chain.
    """
    if len(inputs["relevant_docs"]) == 0:
        return fallback_chain
    else:
        return answer_chain

## Retriever Chain
retriever = get_retriever().with_config({"run_name": "Retriever"})
retriever_chain = itemgetter("question") | retriever

## Full Chain
full_chain = (
    (
        RunnableParallel(
            {"relevant_docs": retriever_chain, "question": lambda x: x["question"]}
        ).with_config({"run_name": "GetQuestionAndRelevantDocs"})
        | RunnableLambda(chain_router)
    )
    .with_types(input_type=Input)
    .with_config({"run_name": "FullChain"})
)
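
For reference, invoking it is the usual LCEL call; the run (including the fallback branch, when it's taken) then shows up as FullChain in LangSmith:

# Example invocation; the exact Input fields depend on the project's model,
# here I'm assuming a single "question" key.
result = full_chain.invoke({"question": "What is the capital of France?"})
print(result)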