Open
devstein opened 1 year ago
langchain/memory/chat_message_histories/cosmos_db.py
To fix `load_messages` returning `None`, and to make `add_user_message` append to the original record instead of replacing it, I would make the following changes to the `CosmosDBChatMessageHistory` class in the `langchain/memory/chat_message_histories/cosmos_db.py` file:
Change the `load_messages` method so that it returns the loaded messages instead of `None`:

```python
# Method of CosmosDBChatMessageHistory. List, BaseMessage, messages_from_dict
# and logger are assumed to come from the module's existing imports.
def load_messages(self) -> List[BaseMessage]:
    """Retrieve the messages from Cosmos."""
    if not self._container:
        raise ValueError("Container not initialized")
    try:
        from azure.cosmos.exceptions import CosmosHttpResponseError
    except ImportError as exc:
        raise ImportError(
            "You must install the azure-cosmos package to use the "
            "CosmosDBChatMessageHistory."
        ) from exc
    try:
        item = self._container.read_item(
            item=self.session_id, partition_key=self.user_id
        )
    except CosmosHttpResponseError:
        # No stored item for this session yet.
        logger.info("no session found")
        return []
    if "messages" in item and len(item["messages"]) > 0:
        self.messages = messages_from_dict(item["messages"])
    # Return the list instead of falling off the end (which returns None).
    return self.messages
```
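The `None` comes from a basic Python rule: a method that only assigns `self.messages` and has no `return` statement hands `None` back to its caller. A minimal, self-contained sketch of that difference, using hypothetical classes, a plain dict in place of the Cosmos container, and strings in place of `BaseMessage` objects:

```python
from typing import Dict, List

# Hypothetical stand-in for the Cosmos container: session_id -> stored messages.
Store = Dict[str, List[str]]


class _History:
    def __init__(self, store: Store, session_id: str) -> None:
        self.store = store
        self.session_id = session_id
        self.messages: List[str] = []


class HistoryWithoutReturn(_History):
    def load_messages(self) -> None:
        # Populates self.messages but has no return statement, so callers get None.
        self.messages = list(self.store.get(self.session_id, []))


class HistoryWithReturn(_History):
    def load_messages(self) -> List[str]:
        # Same load, but the list is also handed back to the caller.
        self.messages = list(self.store.get(self.session_id, []))
        return self.messages


store = {"session-1": ["hi", "hello"]}
print(HistoryWithoutReturn(store, "session-1").load_messages())  # None
print(HistoryWithReturn(store, "session-1").load_messages())  # ['hi', 'hello']
```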
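Change the `add_user_message` method so that it appends to the stored history instead of replacing it:

```python
# Method of CosmosDBChatMessageHistory; HumanMessage comes from the module's imports.
def add_user_message(self, message: str) -> None:
    """Add a user message to the memory."""
    # Reload the stored history into self.messages (load_messages() returns a
    # fresh [] when no item exists, so the result must be assigned back).
    self.messages = self.load_messages()
    self.messages.append(HumanMessage(content=message))
    # upsert_messages() writes self.messages: the old history plus the new message.
    self.upsert_messages()
```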
With these changes, `load_messages` returns the loaded messages, and `add_user_message` appends the new message to the original record instead of replacing it.
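The replace behaviour follows from how an upsert works: Cosmos DB's `upsert_item` writes the whole item, so if the list being written does not already contain the stored history, the old messages are overwritten. The sketch below assumes a hypothetical `FakeContainer` that mimics only this whole-item upsert (it is not the azure-cosmos API) and uses strings instead of `BaseMessage` objects:

```python
from typing import Dict, List


class FakeContainer:
    """Hypothetical in-memory stand-in: upsert replaces the whole item by id."""

    def __init__(self) -> None:
        self.items: Dict[str, dict] = {}

    def upsert_item(self, body: dict) -> None:
        self.items[body["id"]] = dict(body)

    def read_messages(self, item_id: str) -> List[str]:
        return list(self.items.get(item_id, {}).get("messages", []))


def add_without_loading(container: FakeContainer, session_id: str, text: str) -> None:
    # Starts from an empty list, so the upsert replaces any stored history.
    messages: List[str] = []
    messages.append(text)
    container.upsert_item({"id": session_id, "messages": messages})


def add_after_loading(container: FakeContainer, session_id: str, text: str) -> None:
    # Read-modify-write: load the stored history, append, then upsert the full list.
    messages = container.read_messages(session_id)
    messages.append(text)
    container.upsert_item({"id": session_id, "messages": messages})


replaced = FakeContainer()
add_without_loading(replaced, "s1", "first")
add_without_loading(replaced, "s1", "second")
print(replaced.read_messages("s1"))  # ['second']

appended = FakeContainer()
add_after_loading(appended, "s1", "first")
add_after_loading(appended, "s1", "second")
print(appended.read_messages("s1"))  # ['first', 'second']
```

Read-modify-write keeps the history intact for a single writer, but it is not atomic: two concurrent writers for the same session can still overwrite each other's appends.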
To address `CosmosDBChatMessageHistory.load_messages` returning `None` and `add_user_message` performing a replace instead of an append on the original record, we need to update the `BaseChatMessageHistory` class in `langchain/schema.py` and its CosmosDB implementation.
First, let's add a new abstract method called `load_messages` to the `BaseChatMessageHistory` class:
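```python
from abc import ABC, abstractmethod
from typing import List

# In langchain/schema.py, where BaseMessage is defined.
class BaseChatMessageHistory(ABC):
    ...  # existing members unchanged

    @abstractmethod
    def load_messages(self) -> List[BaseMessage]:
        """Load messages from the store."""
```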
Next, create a class that implements `BaseChatMessageHistory` for CosmosDB, called `CosmosDBChatMessageHistory`. It requires the `azure-cosmos` package (`pip install azure-cosmos`).
```python
from typing import List, Optional

from azure.cosmos import CosmosClient
from langchain.schema import (
    AIMessage,
    BaseChatMessageHistory,
    BaseMessage,
    HumanMessage,
    messages_from_dict,
    messages_to_dict,
)


class CosmosDBChatMessageHistory(BaseChatMessageHistory):
    """One Cosmos DB item per session; assumes the container is partitioned on /session_id."""

    def __init__(
        self,
        cosmos_client: CosmosClient,
        database_name: str,
        container_name: str,
        session_id: str,
    ):
        # CosmosClient has no get_container_client(); go through the database client.
        self.container = cosmos_client.get_database_client(
            database_name
        ).get_container_client(container_name)
        self.session_id = session_id

    @property
    def messages(self) -> List[BaseMessage]:
        return self.load_messages()

    def _find_item(self) -> Optional[dict]:
        # Parameterized query instead of f-string interpolation of session_id.
        items = list(
            self.container.query_items(
                query="SELECT * FROM c WHERE c.session_id = @session_id",
                parameters=[{"name": "@session_id", "value": self.session_id}],
                enable_cross_partition_query=True,
            )
        )
        return items[0] if items else None

    def load_messages(self) -> List[BaseMessage]:
        item = self._find_item()
        if item and item.get("messages"):
            return messages_from_dict(item["messages"])
        return []

    def add_message(self, message: BaseMessage) -> None:
        # Read-modify-write: load the stored history, append, upsert the full list.
        messages = self.load_messages()
        messages.append(message)
        self.container.upsert_item(
            {
                "id": self.session_id,  # Cosmos DB requires an "id" on every item.
                "session_id": self.session_id,
                "messages": messages_to_dict(messages),
            }
        )

    def add_user_message(self, message: str) -> None:
        self.add_message(HumanMessage(content=message))

    def add_ai_message(self, message: str) -> None:
        self.add_message(AIMessage(content=message))

    def clear(self) -> None:
        item = self._find_item()
        if item:
            self.container.delete_item(item, partition_key=self.session_id)
```
This new `CosmosDBChatMessageHistory` class should fix `load_messages` returning `None` and `add_user_message` performing a replace instead of an append on the original record.
Issue you'd like to raise.
When I use `CosmosDBChatMessageHistory`, the conversation history is stored in the database, but `load_messages` always returns `None`, and each call to `add_user_message` performs a replace operation on the original record instead of appending to it.
Suggestion:
No response