Open tomdgr opened 2 hours ago
To resolve the KeyError: 'invoke'
when using the Answer Relevance COT feedback function in a streaming setup with TruLens and Langchain's chain.astream()
function, you need to ensure that the invoke
method is correctly referenced and used. Here is a corrected version of your code snippet that should help resolve the issue:
import asyncio
import json
import logging
import os
from operator import itemgetter
from typing import AsyncIterator, Dict, List, Optional

import numpy as np
from langchain.memory import ConversationBufferMemory
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.runnables import Runnable, RunnableLambda, RunnableParallel
from pydantic import BaseModel, Field
from trulens.apps.langchain import TruChain
from trulens.core import Feedback, Select, TruSession
from trulens.providers.openai import AzureOpenAI
class FeedbackData(BaseModel):
    """User-submitted feedback payload (declared but unused in this snippet)."""

    # Name of the user submitting the feedback.
    username: str
    # Free-text feedback body.
    feedback: str
class Message(BaseModel):
    """A single chat turn in a conversation."""

    # Speaker role; code in this file branches on "user" / "assistant".
    role: str
    # The message text.
    content: str
class Conversation(BaseModel):
    """Request payload describing one chat session and its message history."""

    # Display title of the conversation.
    title: str
    # Client-side conversation identifier.
    id: str
    # File associated with the conversation (presumably the uploaded source document — TODO confirm).
    filename: str
    # Optional search-index name used by the retriever.
    indexName: Optional[str] = None
    # Optional prompt persona/type (e.g. "juridisk ekspert").
    promptType: Optional[str] = None
    # Ordered chat history; the last entry is treated as the current query.
    messages: List[Message]
# Sample conversation payload used to exercise the chain in this repro script.
conversation = Conversation(
    title="Ny Samtale",
    id="conversation-1234",
    filename="sample_file.txt",
    indexName="livsvitenskap",
    promptType="juridisk ekspert",
    messages=[
        Message(role="user", content="What is the legal precedent for this case?")
    ]
)
# Azure OpenAI feedback provider for TruLens.
# FIX: the original referenced an undefined name `AZURE_OPENAI_VERSION`
# (NameError at import time); read it from the environment like the endpoint.
# NOTE(review): the openai SDK / TruLens provider expects `azure_endpoint`,
# not `endpoint` — confirm against the installed trulens version.
openai_provider = AzureOpenAI(
    deployment_name="gpt4o",
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_version=os.environ["AZURE_OPENAI_VERSION"],
)
# TruLens session backing store + dashboard.
tru = TruSession()
# WARNING: wipes all previously recorded traces and feedback — dev/repro only.
tru.reset_database()
tru.start_dashboard()
def feedback_cot_function(input_data, output_data):
    """Score answer relevance with chain-of-thought reasoning.

    Delegates to the module-level Azure OpenAI provider and returns its
    result (score plus reasons) unchanged.
    """
    return openai_provider.relevance_with_cot_reasons(input_data, output_data)
# Wire the CoT relevance scorer to the recorded app's main input and main output.
f_feedback_cot = Feedback(feedback_cot_function, name="Answer Relevance COT").on_input().on_output()
def feedback_groundedness_function(context, response):
    """Measure how well *response* is grounded in *context*, with CoT reasons.

    Thin wrapper around the module-level provider; the return value is
    passed through untouched.
    """
    result = openai_provider.groundedness_measure_with_cot_reasons(context, response)
    return result
# Groundedness: checks the final response against the retrieved context.
# Context is read from the first middle step's return value in the recorded
# trace (Select.Record.app.middle[0].rets) — assumes that step is the
# retriever; TODO confirm against the actual trace layout.
f_groundedness = (
    Feedback(feedback_groundedness_function, name="Groundedness")
    .on(Select.Record.app.middle[0].rets)
    .on_output()
    .aggregate(np.mean)
)
# Context relevance: scores each retrieved document against the user input,
# then averages the per-document scores with np.mean. Uses the same trace
# lens as Groundedness for the retrieved context.
f_context_relevance = (
    Feedback(openai_provider.context_relevance_with_cot_reasons, name="Context Relevance")
    .on_input()
    .on(Select.Record.app.middle[0].rets)
    .aggregate(np.mean)
)
async def get_chain(conversation: Conversation, prompt_type: str = "RESPONSE") -> Runnable:
    """Build the retrieval-augmented LCEL chain for *conversation*.

    FIX: the original annotated the return as ``AsyncIterator[str]``, but the
    function returns the composed Runnable chain — callers stream by invoking
    ``chain.astream(...)`` themselves.

    Args:
        conversation: The chat session the retriever/prompt are built around.
        prompt_type: Prompt template variant (passed to ``construct_prompt``).

    Returns:
        A Runnable producing a dict with "response" and "context" keys.
    """
    hyper_params = initialize_hyperparameters()
    retriever = initialize_retriever(hyper_params=hyper_params, conversation=conversation)
    prompt = construct_prompt(conversation=conversation, prompt_type=prompt_type, hyper_params=hyper_params)
    llm = initialize_llm(hyper_params=hyper_params)

    # Fetch supporting documents for the incoming conversation payload.
    # NOTE: `get_relevant_documents` is deprecated in newer langchain-core;
    # prefer `retriever.invoke(...)` once the pinned version allows it.
    context_chain = RunnableLambda(lambda x: retriever.get_relevant_documents(x["conversation"]))
    # Derive the LLM query from the conversation payload.
    query_chain = RunnableLambda(lambda x: create_query(x["conversation"]))

    final_chain = (
        RunnableParallel(
            {
                "context": context_chain,
                "query": query_chain,
                "history": itemgetter("history"),
            }
        )
        # Generate the answer from the assembled inputs...
        .assign(response=prompt | llm)
        # ...and expose only what downstream consumers (and the feedback
        # functions reading middle[0].rets) need.
        .pick(["response", "context"])
    )
    return final_chain
async def run_example() -> str:
    """Run one streamed turn through the chain under TruLens recording.

    FIX: the original used ``await``, ``with``-recording and ``async for`` at
    module scope, which is a SyntaxError — all async work now lives inside a
    coroutine driven by ``asyncio.run``. The unused ``input_data`` dict from
    the original was dropped (it was never passed to ``astream``).

    Returns:
        The accumulated bot reply text.
    """
    chain = await get_chain(conversation=conversation, prompt_type="RESPONSE")
    tru_recorder = TruChain(
        app_name="testapp",
        app=chain,
        provider=openai_provider,
        feedbacks=[f_groundedness, f_feedback_cot, f_context_relevance],
    )

    # Replay stored messages into a buffer memory to build the history string.
    memory = ConversationBufferMemory(memory_key="history", return_messages=False)
    for message in conversation.messages:
        if message.role == "user":
            memory.chat_memory.add_user_message(message.content)
        elif message.role == "assistant":
            memory.chat_memory.add_ai_message(message.content)
    history = memory.load_memory_variables({})["history"]

    bot_reply = ""
    with tru_recorder as recording:
        async for chunk in chain.astream({"conversation": conversation, "history": history}):
            if "response" in chunk:
                content = chunk["response"]
                if isinstance(content, AIMessageChunk):
                    bot_reply += content.content
                else:
                    raise Exception("Unexpected response format, please check the chain or start a new conversation.")
    return bot_reply


if __name__ == "__main__":
    asyncio.run(run_example())
async def yield_chain_response_with_metrics(conversation, history, chain, tru_recorder):
    """Stream the chain's response chunks, then emit a final JSON metrics blob.

    Yields each response text chunk as it arrives from ``chain.astream``;
    after the stream completes, fetches the TruLens feedback results and
    yields them as one JSON line (``{"metrics": {...}}``).

    FIX: ``json`` was used here but never imported (now imported at file
    top); the dead ``latest_message``/``input_data`` locals, which were never
    passed to ``astream``, were removed.
    """
    # TruLens' recorder may install its own task factory; clear it while we
    # drive the stream ourselves and restore it afterwards.
    loop = asyncio.get_running_loop()
    original_task_factory = loop.get_task_factory()
    loop.set_task_factory(None)
    try:
        bot_reply = ""
        with tru_recorder as recording:
            async for chunk in chain.astream({"conversation": conversation, "history": history}):
                if "response" in chunk:
                    content = chunk["response"]
                    if isinstance(content, AIMessageChunk):
                        bot_reply += content.content
                        yield content.content
                    else:
                        raise Exception("Something went wrong in yield chunk in chain response, try opening a new chat")
        # NOTE(review): feedback functions run asynchronously; results may not
        # be ready the instant the stream ends — confirm, or wait on the
        # recording's feedback futures before reading.
        # NOTE(review): the endpoint registers the app under a generated id,
        # but this queries "testapp" — these likely need to agree.
        df, feedback_columns_list = tru.get_records_and_feedback(app_ids=["testapp"])
        metrics_columns = ['Answer Relevance COT', 'Groundedness', 'Context Relevance']
        metrics_dict = df[metrics_columns].iloc[0].to_dict()
        metrics_json = json.dumps({'metrics': metrics_dict})
        print(metrics_json)
        yield f"{metrics_json}\n\n"
    except Exception as e:
        # Best-effort: swallow errors so the HTTP stream terminates cleanly,
        # but surface them in the server log.
        print("An Unexpected error happened: ", e)
    finally:
        loop.set_task_factory(original_task_factory)
@chatbot_bp.route(route=f"{prefix}/query_openai_stream", methods=[func.HttpMethod.POST])
async def query_openai_stream_endpoint(req: Request) -> StreamingResponse:
    """Handle a streamed chat request: build the chain, record it with
    TruLens, and stream the response (plus trailing metrics) back as SSE.
    """
    try:
        payload = await req.json()
        convo = Conversation(**payload)

        # Replay stored messages into a buffer memory to produce the
        # history string the prompt expects.
        buffer = ConversationBufferMemory(memory_key="history", return_messages=False)
        for msg in convo.messages:
            if msg.role == "user":
                buffer.chat_memory.add_user_message(msg.content)
            elif msg.role == "assistant":
                buffer.chat_memory.add_ai_message(msg.content)
        chat_history = buffer.load_memory_variables({})["history"]

        generated_app_id = generate_trulens_app_id()
        rag_chain = await get_chain(conversation=convo, prompt_type="RESPONSE")
        recorder = TruChain(
            app=rag_chain,
            provider=openai_provider,
            app_id=generated_app_id,
            feedbacks=[f_groundedness, f_feedback_cot, f_context_relevance],
        )
        return StreamingResponse(
            yield_chain_response_with_metrics(convo, chat_history, rag_chain, recorder),
            media_type="text/event-stream",
        )
    except Exception as e:
        print(e, "ERROR IN QOSE")
        raise HTTPException(status_code=500, detail=str(e))
This code uses chain.astream()
for the streaming setup and ensures that the invoke
method is correctly referenced and used within the streaming context. Additionally, it includes the correct feedback function relevance
for answer relevance, as shown in the LangChain Quickstart example [1][2][3].
To continue talking to Dosu, mention @dosu.
Dosu's suggestion is tried before, but does not work either, would appreciate human help.
Hi @tomdgr - thanks for the question. We have fixes coming for async and streaming in this PR.
Feel free to install it from the branch if you'd like to try it now, otherwise we'll have it in the next release.
Bug Description
When using the Answer Relevance COT feedback function (Groundedness and Context Relevance also fail) in a streaming setup with TruLens and Langchain's chain.astream(), the feedback evaluation crashes with a KeyError: 'invoke'. This seems to occur because the feedback_cot_function is not compatible with streamed, chunked responses. The function expects a full input and output, but since streaming delivers responses in chunks, the feedback function encounters issues.
NB! I have made this work by using chain.invoke(); however, for my production environment I need to use astream(). I saw an old issue which solved the problem with acall, but that function seems to have been replaced by ainvoke. Link to the mentioned issue. Does anyone here have experience with the trulens + langchain + astream combination?
The error I get when using astream:
To Reproduce
Expected behavior A clear and concise description of what you expected to happen.
Relevant Logs/Tracebacks
Trace from dashboard when looking at evaluations: Minified React error #31; visit https://reactjs.org/docs/error-decoder.html?invariant=31&args[]=object%20with%20keys%20%7B__tru_non_serialized_object%7D for the full message or use the non-minified dev environment for full errors and additional helpful warnings.
Environment:
langchain 0.2.16 langchain-cli 0.0.31 langchain-community 0.2.17 langchain-core 0.2.40 langchain-openai 0.1.25 langchain-text-splitters 0.2.4
Additional context Here is the production environment code: running in fastAPI
The part involving df and the feedbacks in yield_chain_response_with_metrics is what I want to solve.