ajndkr / lanarky

The web framework for building LLM microservices
https://lanarky.ajndkr.com/
MIT License

How do you return source_documents using ConversationalRetrievalChain? #109

Closed (auxon closed this issue 1 year ago)

auxon commented 1 year ago

Here's my code:

from enum import StrEnum
import os
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from lanarky import LangchainRouter, StreamingResponse
from langchain.chains import LLMChain, ConversationalRetrievalChain
from langchain.embeddings import GPT4AllEmbeddings
from langchain.vectorstores import Chroma
from langchain.document_loaders import DirectoryLoader
from langchain.document_loaders import UnstructuredHTMLLoader
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.memory import ConversationSummaryBufferMemory
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    PromptTemplate)
from langchain.memory.chat_message_histories.in_memory import (
    ChatMessageHistory)
from pydantic import BaseModel, constr
from AGPT4All import AGPT4All

load_dotenv()  # load environment variables from .env file
model_path = os.environ.get("MODEL_PATH")
titleDownloads = os.environ.get("TITLE_DOWNLOADS")
print(f"TITLE_DOWNLOADS={titleDownloads}")
print(f"MODEL_PATH={model_path}")

app = FastAPI(title="AI", version="0.0.1",
              description="AI Service")

app.mount("/static", StaticFiles(directory="static"), name="static")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

callbacks = [AsyncIteratorCallbackHandler(), StreamingStdOutCallbackHandler()]

llm = AGPT4All(model=model_path, callbacks=callbacks, verbose=True,
               max_tokens=4096, n_predict=4096, streaming=True)

embeddings = GPT4AllEmbeddings()

# Reuse a persisted Chroma index if one exists; otherwise build it from the HTML downloads.
if os.path.exists("./chroma_db") and os.path.isdir("./chroma_db"):
    vectorstore = Chroma(persist_directory="./chroma_db",
                         embedding_function=embeddings)
else:
    loader = DirectoryLoader(
        f'{titleDownloads}',
        loader_cls=UnstructuredHTMLLoader,
        recursive=True, glob="**/*.html", show_progress=True)
    documents = loader.load()
    print(f"Total docs: {len(documents)}")
    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
    chunks = splitter.split_documents(documents)
    texts = [doc.page_content for doc in chunks]
    metadatas = [doc.metadata for doc in chunks]
    vectorstore = Chroma.from_texts(texts=texts,
                                    embedding=embeddings,
                                    metadatas=metadatas,
                                    persist_directory="./chroma_db")

retriever = vectorstore.as_retriever()

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

class Role(StrEnum):
    SYSTEM = "system"
    ASSISTANT = "assistant"
    USER = "user"

class Message(BaseModel):
    role: constr(
        regex=f"^({Role.ASSISTANT}|{Role.USER}|{Role.SYSTEM})$")  # NOQA
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: list[Message]
    max_tokens: int
    temperature: float

def create_chain(messages: list[Message]):
    """Build a ConversationalRetrievalChain seeded with the prior chat messages."""

    condense_template_str = (
        "You are an expert at summarizing chat histories and questions. "
        "Given the following Chat History and a Follow Up Question, "
        "rephrase the follow up question to be a new Standalone Question. \n "
        "Chat History: \n"
        "{chat_history} \n"
        "Follow Up Question: {question} \n"
        "Standalone question:")
    condense_template = PromptTemplate.from_template(
        f"{B_INST} {B_SYS}{condense_template_str.strip()}{E_SYS} {E_INST}")

    prompt_template_str = (
        "You are a helpful AI. "
        "Use the following Context and Chat History to answer the "
        "question at the end with a helpful and detailed answer. "
        "If you don't know the answer, just say "
        "'I don't know'; don't try to make up an answer. \n"
        "Context: {context} \n"
        "Chat History: {chat_history} \n"
        "Question: {question} \n"
        "Helpful Answer:")
    qa_prompt_template = ChatPromptTemplate.from_template(
        f"{B_INST} {B_SYS}{prompt_template_str.strip()}{E_SYS} {E_INST}")

    chat_memory = ChatMessageHistory()
    for message in messages:
        if message.role == Role.USER:
            chat_memory.add_user_message(message.content)
        elif message.role == Role.ASSISTANT:
            chat_memory.add_ai_message(message.content)

    memory = ConversationSummaryBufferMemory(
        llm=llm,
        chat_memory=chat_memory,
        memory_key="chat_history",
        input_key="question",
        return_messages=True)

    question_generator = LLMChain(llm=llm, prompt=condense_template,
                                  memory=memory, verbose=True)

    doc_chain = load_qa_chain(llm=llm, chain_type="stuff",
                              prompt=qa_prompt_template, verbose=True)

    return ConversationalRetrievalChain(
            combine_docs_chain=doc_chain,
            memory=memory,
            retriever=retriever,
            question_generator=question_generator,
            return_generated_question=True,
            return_source_documents=True,
            output_key="answer",
            verbose=True)

router = LangchainRouter(
    streaming_mode=1,
    # llm_cache_mode=3,  # GPTCache
)

@router.post(
    "/chat",
    summary="AI Chat",
    description="Chat with AI Service",
)
def chat(request: ChatRequest):
    chain = create_chain(
        # model=request.model,
        messages=request.messages[:-1],
        # max_tokens=request.max_tokens,
        # temperature=request.temperature
        )
    return StreamingResponse.from_chain(
        chain, request.messages[-1].content, as_json=False)

app.include_router(router, tags=["chat"])

Everything is working except that I also want the source documents and their metadata returned. If I set as_json=True in StreamingResponse.from_chain, the response is a long stream of "token" events that does include the sources and metadata, but in an unexpected response format.

Originally posted by @auxon in https://github.com/ajndkr/lanarky/discussions/108
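
For context, a minimal sketch of one way to surface source_documents (not necessarily the approach from the linked discussion): because the chain is built with return_source_documents=True, invoking it directly, outside StreamingResponse.from_chain, returns a dict that includes them. The /chat/sources route name and the response shape below are made up for illustration; the sketch reuses the create_chain and ChatRequest defined above and gives up streaming for this endpoint.

# Hypothetical non-streaming companion endpoint; route name and response
# shape are illustrative only. With return_source_documents=True, calling
# the chain directly yields a dict containing "answer" and "source_documents".
@router.post(
    "/chat/sources",
    summary="AI Chat (with sources)",
    description="Chat with AI Service and return source documents",
)
def chat_with_sources(request: ChatRequest):
    chain = create_chain(messages=request.messages[:-1])
    result = chain({"question": request.messages[-1].content})
    return {
        "answer": result["answer"],
        "sources": [
            {"content": doc.page_content, "metadata": doc.metadata}
            for doc in result["source_documents"]
        ],
    }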

auxon commented 1 year ago

Fixed. Find the solution in the discussion: https://github.com/ajndkr/lanarky/discussions/108#discussion-5501534