Chainlit / chainlit

Build Conversational AI in minutes ⚡️
https://docs.chainlit.io
Apache License 2.0
6.87k stars 904 forks source link

on_chat_start thread issue #581

Open mangled-data opened 10 months ago

mangled-data commented 10 months ago

I had refactored the code in https://github.com/Chainlit/docs/blob/aa50e881dde660ee7393b0b80427e6c88524a3a5/examples/qa.mdx#L53

I am missing something about the flow. I thought on_chat_start would be invoked once per run, but it seems to be invoked periodically. Because I load embeddings from a cache, as far as I can tell that loading is repeated over and over. A simpler synchronous implementation might help, because it is not easy to tell which handler gets called when (and I also wonder whether race conditions would be harder to debug).

@cl.on_chat_start
async def on_chat_start():
    """Build the retrieval chain for a new chat and store it in the user session.

    The vector store is either built from a text file the user uploads, or
    (the currently active path) loaded from the local ``.embeddings`` cache
    via ``EmbeddingsCache``.  The resulting ``ConversationalRetrievalChain`` is
    saved under the ``"chain"`` key of ``cl.user_session`` for the message
    handler to use.
    """
    # Explicit switch between "ask the user for a file" and "load the cached
    # store"; replaces two bare ``if False:`` toggles with identical behavior.
    use_uploaded_file = False

    files = None

    if use_uploaded_file:
        # Keep asking until the prompt actually returns files.
        while files is None:
            files = await cl.AskFileMessage(
                content="Please upload a text file to begin!",
                accept=["text/plain"],
                max_size_mb=20,
                timeout=180,
            ).send()

        file = files[0]

        msg = cl.Message(
            content=f"Processing `{file.name}`...", disable_human_feedback=True
        )
        await msg.send()

        # Decode the file
        text = file.content.decode("utf-8")

        # Split the text into chunks
        texts = text_splitter.split_text(text)

        # Create a metadata for each chunk
        metadatas = [{"source": f"{i}-pl"} for i in range(len(texts))]

    # Create a Chroma vector store (or load the cached one).
    embeddings = OpenAIEmbeddings()
    if use_uploaded_file:
        docsearch = await cl.make_async(Chroma.from_texts)(
            texts, embeddings, metadatas=metadatas
        )
    else:
        msg = cl.Message(content="Processing Cache...", disable_human_feedback=True)
        await msg.send()

        embedder = EmbeddingsCache(".embeddings")
        docsearch = embedder.get_db()

    message_history = ChatMessageHistory()

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=message_history,
        return_messages=True,
    )

    # Previously 'gpt-3.5-turbo' was assigned and immediately overwritten;
    # only the effective value is kept.
    model_name = "gpt-3.5-turbo-16k"

    retriever = docsearch.as_retriever(search_kwargs={"k": 5})

    # Create a chain that uses the vector store as its retriever.
    chain = ConversationalRetrievalChain.from_llm(
        ChatOpenAI(model_name=model_name, temperature=0, max_tokens=4000, streaming=True),
        chain_type="stuff",
        retriever=retriever,
        memory=memory,
        return_source_documents=True,
    )

    # Store the chain before telling the user they can ask questions, so the
    # message handler can find it as soon as the "ready" text is shown.
    cl.user_session.set("chain", chain)

    # Let the user know that the system is ready; otherwise the status message
    # stays at "Processing ..." indefinitely.
    if use_uploaded_file:
        msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    else:
        msg.content = "Processing Cache done. You can now ask questions!"
    await msg.update()

@cl.on_message
async def main(message: cl.Message):
    """Answer a user message with the session's retrieval chain and cite sources.

    Each source document returned by the chain is attached to the reply as a
    ``cl.Text`` element named ``source_<i>``, and those names are listed at the
    end of the answer.  When no sources come back, the answer says so.
    """
    chain = cl.user_session.get("chain")  # type: ConversationalRetrievalChain
    if chain is None:
        # No chain in this session (e.g. chat start did not complete, or the
        # thread was resumed); fail gracefully instead of AttributeError.
        await cl.Message(content="Sorry, I am unable to process your request").send()
        return

    cb = cl.AsyncLangchainCallbackHandler()

    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["answer"]
    source_documents = res["source_documents"]  # type: List[Document]

    # One text element per retrieved chunk; the element name is the reference
    # that appears in the answer text.
    text_elements = [
        cl.Text(content=source_doc.page_content, name=f"source_{source_idx}")
        for source_idx, source_doc in enumerate(source_documents or [])
    ]

    # The "no sources" fallback sits at this level on purpose: when it was
    # nested under ``if source_documents:`` it was unreachable, because a
    # non-empty document list always produced non-empty names.
    if text_elements:
        source_names = ", ".join(text_el.name for text_el in text_elements)
        answer += f"\nSources: {source_names}"
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer, elements=text_elements).send()
willydouhard commented 10 months ago

on_chat_start is called once per user. Every time a user sends a message, on_message is called.

LeoThuanYen commented 5 months ago

I handle the file in on_message instead: check whether the user sent a file with the message, and also check whether a chain is already stored in the session so you can retrieve it there. On resume (on_chat_resume), check whether a chain field exists in thread.metadata; if it does, reload the vector DB from your vector store (I suggest naming the collection after the thread ID). Something like this:

 if  message.elements or cl.user_session.get("chain") is not None:
        files = message.elements
        if files:
            await process_file(files)# process file and set chain in session 
        chain = cl.user_session.get("chain")  # type: ConversationalRetrievalChain
        if chain is None:
            await cl.Message(content="Sorry, I am unable to process your request", disable_feedback=True).send()
            return