langchain-ai / langchain-google

VertexAICheckGroundingWrapper breaks streaming #415

Open shenghann opened 1 month ago

shenghann commented 1 month ago

I would like to use this wrapper as part of an LCEL chain that passes both the answer_candidate and the documents as inputs to the wrapper runnable.

The way VertexAICheckGroundingWrapper needs to be invoked with configs is breaking streaming in LCEL chains. Why is it written in a way that requires input of documents/claims via configurables and not through typical RunnableSerializable inputs?

Right now this is how we invoke the wrapper:

# existing
output_parser.with_config(configurable={"documents": documents}).invoke(answer_candidate)

# desired way to invoke
output_parser.invoke({"answer_candidate": answer_candidate, "documents": documents})

In a typical RAG LCEL chain, the only way to incorporate this wrapper to perform check grounding is to wrap this wrapper in a RunnableLambda, and because it requires with_config, the inputs have to be "finalized", which breaks streaming in LCEL. More on how finalized inputs break streaming.
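To illustrate the point about finalized inputs, here is a minimal plain-Python sketch (no LangChain involved; the function names are made up for illustration). A stage that transforms chunk by chunk keeps the stream alive, while a stage that needs the whole input has to buffer everything and can only emit once at the end:

```python
from typing import Iterable, Iterator


def per_chunk_stage(chunks: Iterable[str]) -> Iterator[str]:
    """Streaming-friendly stage: emits one output for every input chunk."""
    for chunk in chunks:
        yield chunk.upper()


def finalized_stage(chunks: Iterable[str]) -> Iterator[str]:
    """Stage that needs the complete input: buffers all chunks, emits once."""
    full_text = "".join(chunks)  # nothing can be emitted before this finishes
    yield f"checked({full_text})"


tokens = ["John ", "is ", "a ", "doctor."]

streamed = list(per_chunk_stage(tokens))
finalized = list(finalized_stage(tokens))

print(len(streamed))   # one output per token
print(len(finalized))  # a single output for the whole stream
```

This is only an analogy for what happens in the chain: anything downstream of the finalizing stage sees a single chunk, no matter how many chunks arrived upstream.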

cc contributor @Abhishekbhagwat

Abhishekbhagwat commented 1 month ago

Hi @shenghann, thanks for trying out the VertexAICheckGroundingWrapper. Could you share a bit more about how you are using it for streaming? I would like to clarify whether what you want is to avoid passing documents as an additional config parameter.

AFAIK, the Vertex AI Check Grounding API does not support partial grounding checks, so your inputs would need to be finalized in any case.

VertexAICheckGroundingWrapper does inherit from RunnableSerializable, so you can also do this as an alternative:

output_parser.invoke(answer_candidate, config={'configurable': {'documents': documents}})

shenghann commented 1 month ago

Thanks for the quick response @Abhishekbhagwat!

Here's a minimal example of how the wrapper can be used as part of an LCEL chain:

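from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough

# retriever, model, and output_parser (a VertexAICheckGroundingWrapper) are assumed to be defined elsewhere
message = """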
Answer this question using the provided context only.

{question}

Context:
{context}
"""

prompt = ChatPromptTemplate.from_messages([("human", message)])

rag_chain = (
    {"context": retriever, "question": RunnablePassthrough()} 
    | RunnableParallel({
        "answer": prompt | model,
        "context": itemgetter("context"),
    })
)
chain = rag_chain | output_parser

chunks = []
for chunk in chain.stream("What is John's profession?"):
    chunks.append(chunk)
    print(chunk, end="", flush=True)

Notice how the rag_chain outputs two keys: answer (the candidate answer) and context (the list of documents). I would like to be able to pass these two directly as inputs to VertexAICheckGroundingWrapper and have it return the grounding result.

If you run the code above you'll get an error: ValueError: Configuration is required.

What would be the best way for me to incorporate VertexAICheckGroundingWrapper in my streaming chain?

Abhishekbhagwat commented 1 month ago

Hi @shenghann, thanks for the minimal example. This is much clearer now. As I mentioned earlier, VertexAICheckGroundingWrapper does not support streaming responses because it needs its inputs to be finalized. That said, I understand that the example above breaks your streaming chain. We can still use VertexAICheckGroundingWrapper so that it returns the entire response as a single chunk of the stream.

Here is an example of how to do this without breaking your streaming chain:

from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough

# Define the prompt template
message = """
Answer this question using the provided context only.

{question}

Context:
{context}
"""

prompt = ChatPromptTemplate.from_template(message)

# Construct the simplified RAG chain
chain = (
    RunnableParallel(
        {
            "question": RunnablePassthrough(),
            "context": retriever
        }
    )
    | RunnableParallel(
        {
            "response": prompt | llm,
            "context": lambda x: x["context"]  # Pass through the retrieved documents
        }
    )
    | RunnableLambda(lambda x: output_parser.invoke(
        x["response"],
        config={"configurable": {"documents": x["context"]}}
    ))
)

# Run the chain and stream the results
chunks = []
for chunk in chain.stream("What is John's profession?"):
    chunks.append(chunk)
    print(chunk, end="", flush=True)

Do try this out and let me know if it prevents breaking of the chain :)
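The inline lambda above can also be pulled out into a small named adapter that accepts the dict shape requested earlier in the thread ({"answer_candidate": ..., "documents": ...}). The sketch below is only an illustration: check_grounding and FakeGroundingParser are hypothetical names, and the fake parser's return value is invented so the snippet runs without Google Cloud credentials. The adapter also unwraps a .content attribute when the model returns a message object instead of a string, on the assumption that the wrapper expects plain text.

```python
from typing import Any, Dict, Optional


class FakeGroundingParser:
    """Hypothetical stand-in for the real wrapper; only mimics its call shape."""

    def invoke(self, input: str, config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        if not config:
            # Mirrors the error message reported earlier in this thread.
            raise ValueError("Configuration is required.")
        documents = config["configurable"]["documents"]
        # Invented return shape, for illustration only.
        return {"answer": input, "num_documents": len(documents)}


def check_grounding(payload: Dict[str, Any], parser: Any) -> Any:
    """Map a dict input onto the config-based invocation the wrapper needs."""
    answer = payload["answer_candidate"]
    # Chat models return message objects; fall back to the raw value for strings.
    answer_text = getattr(answer, "content", answer)
    return parser.invoke(
        answer_text,
        config={"configurable": {"documents": payload["documents"]}},
    )


result = check_grounding(
    {"answer_candidate": "John is a doctor.", "documents": ["doc-1", "doc-2"]},
    FakeGroundingParser(),
)
print(result)
```

In a real chain the same function could be wrapped as RunnableLambda(lambda x: check_grounding(x, output_parser)), after renaming the upstream keys to answer_candidate and documents. It would still emit a single chunk, since the grounding check needs the complete answer.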

Abhishekbhagwat commented 1 month ago

Hi @shenghann, just wanted to check in to see if this fixes your issue. Thanks!