Chainlit / chainlit

Build Conversational AI in minutes ⚡️
https://docs.chainlit.io
Apache License 2.0

[Question] NVIDIA NeMo Guardrails compatibility #496

Closed austinmw closed 10 months ago

austinmw commented 11 months ago

Hi, is this library compatible with NVIDIA NeMo Guardrails?

Guardrails can wrap LangChain, and I'd love to then wrap that in a Chainlit interface. Thanks!

Here's a quick example:

from nemoguardrails import LLMRails, RailsConfig

config = RailsConfig.from_path("path/to/config")
app = LLMRails(config)

# ... initialize `docsearch`

qa_chain = RetrievalQA.from_chain_type(
    llm=app.llm, chain_type="stuff", retriever=docsearch.as_retriever()
)
app.register_action(qa_chain, name="qa_chain")

history = [
    {"role": "user", "content": "What is the current unemployment rate?"}
]
result = app.generate(messages=history)
print(result)
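For context, `RailsConfig.from_path` points at a directory that typically holds a `config.yml` (model settings) plus one or more Colang `.co` files defining the rails. A minimal sketch of the YAML side — the file name follows the NeMo Guardrails convention, but the engine/model values here are illustrative assumptions, not taken from this thread:

```yaml
# path/to/config/config.yml — minimal illustrative example
models:
  - type: main
    engine: openai
    model: gpt-3.5-turbo
```

The rail flows themselves live in the `.co` files alongside it; see the NeMo Guardrails docs for the Colang syntax.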

I got this far, but am stuck on figuring out how to add callbacks:

import chainlit as cl
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from nemoguardrails import LLMRails, RailsConfig

@cl.on_chat_start
def main():
    # Instantiate the chain for that user session
    loader = TextLoader("./paul_graham_essay.txt")
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)
    embeddings = OpenAIEmbeddings()
    docsearch = Chroma.from_documents(texts, embeddings)
    llm = ChatOpenAI(model='gpt-4')
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm, chain_type="stuff", retriever=docsearch.as_retriever(), verbose=True
    )

    config = RailsConfig.from_path("./config/example")
    app = LLMRails(config)
    app.register_action(qa_chain, name="qa_chain")

    # Store the chain in the user session
    cl.user_session.set("app", app)

@cl.on_message
async def main(message: cl.Message):
    # Retrieve the chain from the user session
    app = cl.user_session.get("app")  # type: LLMRails

    history = [
        {"role": "user", "content": message.content}
    ]

    # Call the rails app asynchronously with the chat history
    res = await app.generate_async(
        messages=history, callbacks=[cl.AsyncLangchainCallbackHandler()]
    )

    # Do any post processing here

    # "res" is a dict. For LLMRails, the reply text is under the "content" key
    # (other chains may use a different key, e.g. "text").
    await cl.Message(content=res["content"]).send()

willydouhard commented 10 months ago

If you can't pass the callback when running the chain, you can try passing it when you instantiate the chain instead (in `on_chat_start`).
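The idea behind this suggestion can be sketched without Chainlit or LangChain installed. `MiniChain` and `LogHandler` below are hypothetical stand-ins, not the real APIs: callbacks bound at construction time fire on every run, so they keep working even when the code path that triggers the run can't forward them.

```python
class LogHandler:
    """Hypothetical callback handler that records every run."""
    def __init__(self):
        self.events = []

    def on_run(self, text):
        self.events.append(text)


class MiniChain:
    """Hypothetical chain: accepts callbacks at construction or per call."""
    def __init__(self, callbacks=None):
        # Constructor-scoped callbacks: attached once, used for all calls
        self.callbacks = list(callbacks or [])

    def run(self, text, callbacks=None):
        # Per-call callbacks (if any) are merged with constructor-scoped ones
        for cb in self.callbacks + list(callbacks or []):
            cb.on_run(text)
        return text


handler = LogHandler()
chain = MiniChain(callbacks=[handler])  # pass the handler at instantiation
chain.run("hello")                      # no callbacks forwarded here...
chain.run("world")                      # ...yet the handler still fires
print(handler.events)                   # ['hello', 'world']
```

In the thread's setup, the analogue would be attaching the Chainlit handler when building `qa_chain` in `on_chat_start` rather than at `generate_async` time, per the suggestion above.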

austinmw commented 10 months ago

Worked, thanks!