langchain-ai / langgraph

Build resilient language agents as graphs.
https://langchain-ai.github.io/langgraph/
MIT License
5.89k stars 927 forks source link

Unable to run agent with fine-tuned gpt-4o-mini model #1719

Closed DhruvCMH closed 2 weeks ago

DhruvCMH commented 2 weeks ago

Checked other resources

Example Code

# Configure logger for the module
logger = logging.getLogger(__name__)

# Define embeddings models
embeddings_model = OpenAIEmbeddings(model=EMBEDDINGS_MODEL_NAME, api_key=OPENAI_API_KEY)

# Define language model
llm = ChatOpenAI(model=LLM_MODEL_NAME, openai_api_key=OPENAI_API_KEY, temperature=LLM_TEMPERATURE, seed=SEED, stream_usage=True)

# Statefully manage chat history
store = {}

# Define text splitter
text_splitter = RecursiveCharacterTextSplitter(chunk_size=TEXT_SPLITTER_CHUNK_SIZE)

# Define message trimmer
trimmer = trim_messages(
    max_tokens=TRIM_MESSAGES_MAX_TOKENS,
    strategy="last",
    token_counter=llm,
    include_system=True,
    allow_partial=False,
    start_on="human",
)

# Define message filter
filter_ = filter_messages(exclude_types=["tool"])

# Useful for processing data if it is scraped in markdown format
EXCLUDE_CONTENT = list()
if SCRAPED_DATA_PROCESSING["process_data"]:
    exclude_content_file_path = SCRAPED_DATA_PROCESSING["process_data_file_path"]

    # Read file
    with open(exclude_content_file_path, 'r') as file:
        lines = file.readlines()

    EXCLUDE_CONTENT = [line.strip() for line in lines if line.strip()]
    print(EXCLUDE_CONTENT)

def split_document(docs):
    """
    Split the document into smaller chunks.
    """
    try:
        logger.debug(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} splitting document")
        all_splits = text_splitter.split_documents(docs)
        logger.debug(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} successfully split document")
    except Exception as e:
        all_splits = list()
        logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} failed to split document")
        logger.exception(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} exception caused - {e}")
    return all_splits

def create_new_document(data):
    """
    Custom document creater from scraped markdown data.
    """
    page_content = ""
    for i in data:
        if i.page_content not in EXCLUDE_CONTENT:
            page_content += "\n\n" + i.page_content
            if "link_urls" in i.metadata.keys():
                for l, m in zip(i.metadata["link_texts"], i.metadata["link_urls"]):
                    page_content += "\n" + l + " " + m + "\n"
    doc = Document(page_content=page_content.strip(), metadata={"source": data[0].metadata["source"]})
    return doc

def load_knowledge_base():
    """
    Load the knowledge base.
    """
    try:
        logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} loading knowledge base")
        all_splits = []
        for file in os.listdir(KNOWLEDGE_BASE_DIR):
            logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} processing file {file}")
            file_path = os.path.join(KNOWLEDGE_BASE_DIR, file)
            if file_path.endswith(".pdf"):
                loader = PyPDFLoader(file_path)
                docs = loader.load()
                splits = split_document(docs)
                all_splits.extend(splits)
            elif file_path.endswith(".csv"):
                loader = CSVLoader(file_path=file_path)
                docs = loader.load()
                all_splits.extend(docs)
            elif file_path.endswith(".md"):
                loader = UnstructuredMarkdownLoader(file_path=file_path, mode="elements")
                data = loader.load()
                if len(data):
                    doc = create_new_document(data)
                    all_splits.append(doc)
            logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} successfully processed file {file}")
        logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} successfully loaded knowledge base")
        logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} total number of knowledge chunks - {len(all_splits)}")
    except Exception as e:
        all_splits = list()
        logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} failed to load knowledge base")
        logger.exception(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} exception caused - {e}")
    return all_splits

def create_retriever():
    """
    Create the retriever.
    """
    try:
        logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} creating retriever")
        docs = load_knowledge_base()

        # Associate summaries with the documents
        chain = (
                {"doc": lambda x: x.page_content}
                | ChatPromptTemplate.from_template("Summarize the following document:\n\n{doc}")
                | llm
                | StrOutputParser()
        )
        summaries = chain.batch(docs, {"max_concurrency": 5})

        # The vectorstore to use to index the child chunks
        vectorstore = Chroma(embedding_function=embeddings_model, persist_directory=VECTOR_DB_DIR)

        # The storage layer for the parent documents
        store = LocalFileStore(PARENT_DOCUMENTS_STORAGE)
        id_key = "doc_id"

        # The retriever (empty to start)
        retriever = MultiVectorRetriever(
            vectorstore=vectorstore,
            byte_store=store,
            id_key=id_key,
        )
        doc_ids = [str(uuid.uuid4()) for _ in docs]
        summary_docs = [
            Document(page_content=s, metadata={id_key: doc_ids[i]})
            for i, s in enumerate(summaries)
        ]

        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(list(zip(doc_ids, docs)))

        for i, doc in enumerate(docs):
            doc.metadata[id_key] = doc_ids[i]
        retriever.vectorstore.add_documents(docs)

        logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} successfully created retriever")
    except Exception as e:
        retriever = None
        logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} failed to create retriever")
        logger.exception(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} exception caused - {e}")
    return retriever

def load_retriever():
    """
    Load the retriever.
    """
    try:
        logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} loading retriever")
        if os.path.exists(VECTOR_DB_DIR) and os.path.exists(PARENT_DOCUMENTS_STORAGE):
            logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} retriever already exist")

            # The vectorstore which indexed the child chunks
            vector_db = Chroma(embedding_function=embeddings_model, persist_directory=VECTOR_DB_DIR)

            # The storage layer for the parent documents
            store = LocalFileStore(PARENT_DOCUMENTS_STORAGE)
            id_key = "doc_id"

            # The retriever (empty to start)
            retriever = MultiVectorRetriever(
                vectorstore=vector_db,
                byte_store=store,
                id_key=id_key,
            )
        else:
            logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} retriever does not exist")
            retriever = create_retriever()
        logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} successfully loaded retriever")
    except Exception as e:
        retriever = None
        logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} failed to load retriever")
        logger.exception(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} exception caused - {e}")
    return retriever

def format_docs(docs):
    """
    Format the documents.
    """
    return "\n\n".join(doc.page_content for doc in docs)

def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """
    Get the chat message history for the session.
    """
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]

def create_follow_up_question_generator_tool():
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                PROMPT_FOR_FOLLOW_UP_QUESTION_GENERATION,
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )

    parser = StrOutputParser()
    chain = prompt | llm | parser
    follow_up_tool = chain.as_tool(
        name="follow_up_question_generation",
        description=FOLLOW_UP_QUESTION_GENERATION_TOOL_DESCRIPTION
    )
    return follow_up_tool

def create_tool():
    """
    Create the tool to retrieve information from vector db.
    """
    # Create retrieval tool
    retriever = load_retriever()
    tool = create_retriever_tool(
        retriever,
        "wine_business_information",
        RAG_RETRIEVER_TOOL_DESCRIPTION
    )

    # Create search tool
    search = GoogleSerperAPIWrapper(serper_api_key=SERPER_API_KEY)
    search_tool = Tool(
        name="online_web_search",
        func=search.run,
        description=SEARCH_TOOL_DESCRIPTION,
    )

    # Create follow up question generator tool
    # follow_up_tool = create_follow_up_question_generator_tool()
    tools = [tool, search_tool]
    # tools = [tool, search_tool, follow_up_tool]
    return tools

def find_consecutive_tool_messages_count(messages):
    """
    Find the count of ToolMessage in the conversation.
    """
    idx = 0
    count = 0
    while isinstance(messages[idx], ToolMessage):
        count += 1
        if idx < len(messages) - 1:
            idx += 1
        else:
            break
    return count

def find_occurrences_tool_messages(messages):
    """
    Find the consecutive occurrences of ToolMessage in the conversation.
    """
    occurrence_count = 0
    in_tool_message_sequence = False

    for message in messages:
        if isinstance(message, ToolMessage):
            if not in_tool_message_sequence:
                occurrence_count += 1
                in_tool_message_sequence = True
        else:
            in_tool_message_sequence = False

    return occurrence_count

def keep_last_messages(messages):
    """
    Keep the last messages in the conversation.
    """
    updated_messages = list()
    c = 0
    for message in messages[::-1]:
        if c < KEEP_MESSAGES_COUNT:
            updated_messages.append(message)
        else:
            break
        c += 1
    return updated_messages[::-1]

@chain
def custom_filter_messages(messages):
    """
    Custom filter messages to remove ToolMessage from in between the conversation.
    """
    messages = keep_last_messages(messages)
    total_occurrences_tool_message = find_occurrences_tool_messages(messages)
    updated_messages = list()

    # Add system message if not present.
    if not isinstance(messages[0], SystemMessage):
        updated_messages.append(SystemMessage(INSTRUCTIONS))

    current_occurrence_tool_message = 0
    num_of_tool_messages = 0
    skip_occurrences = False
    for idx, message in enumerate(messages):
        message.content = message.content.replace("\n", " ")
        if isinstance(message, HumanMessage):
            updated_messages.append(message)
        elif isinstance(message, AIMessage):
            if len(message.content):
                updated_messages.append(message)
            else:
                current_occurrence_tool_message += 1
                if current_occurrence_tool_message < total_occurrences_tool_message:
                    num_of_tool_messages = find_consecutive_tool_messages_count(messages[idx + 1:])
                    skip_occurrences = True
                    continue
                else:
                    updated_messages.append(message)
                    num_of_tool_messages = find_consecutive_tool_messages_count(messages[idx + 1:])
                    skip_occurrences = False
        elif isinstance(message, ToolMessage):
            if skip_occurrences:
                if num_of_tool_messages > 0:
                    num_of_tool_messages -= 1
                    continue
                else:
                    skip_occurrences = False
            else:
                updated_messages.append(message)
    return updated_messages

def create_agent():
    """
    Create the agent to answer user's questions.
    """
    # memory = SqliteSaver.from_conn_string("checkpoints.sqlite")
    memory = AsyncSqliteSaver.from_conn_string("checkpoints.sqlite")
    tools = create_tool()
    messages_filter = custom_filter_messages | trimmer
    print(llm)
    agent_executor = create_react_agent(llm, tools, messages_modifier=messages_filter, checkpointer=memory, debug=False)
    return agent_executor

agent = create_agent()

async def generate_streaming_response_from_agent(question: str, session_id: str):
    ignore_tool_output = False
    async for event in agent.astream_events(
        {"messages": [HumanMessage(content=question)]},
        config={"configurable": {"thread_id": session_id}},
        version="v1",
        debug=True
    ):
        kind = event["event"]
        if kind == "on_tool_start":
            ignore_tool_output = True

        elif kind == "on_tool_end":
            ignore_tool_output = False

        elif kind == "on_chat_model_stream":
            if not ignore_tool_output:
                content = event["data"]["chunk"].content
                if content:
                    content = content.replace('\n', '__NEWLINE__')
                    yield f"data: {content}\n\n"

        elif kind == "on_chat_model_end":
            logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} Session ID: {session_id} and User message: {question}")
            logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} Session ID: {session_id} and Chatbot message: {event['data']['output']['generations'][0][0]['message'].content}")
            logger.info(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')} Session ID: {session_id} and Usage metadata: {event['data']['output']['generations'][0][0]['message'].usage_metadata}")

Error Message and Stack Trace (if applicable)

[-1:checkpoint] State at the end of step -1:
{'messages': []}
[0:tasks] Starting step 0 with 1 task:
- __start__ -> {'messages': [HumanMessage(content='hi')]}
[0:writes] Finished step 0 with writes to 1 channel:
- messages -> [HumanMessage(content='hi')]
[0:checkpoint] State at the end of step 0:
{'messages': [HumanMessage(content='hi', id='4409af73-5303-4fec-b5e3-5d55a8bcfc9f')]}
[1:tasks] Starting step 1 with 1 task:
- agent -> {'is_last_step': False,
 'messages': [HumanMessage(content='hi', id='4409af73-5303-4fec-b5e3-5d55a8bcfc9f')]}
[14/Sep/2024 10:19:11] "GET /stream_reply_from_chatbot/?session_id=53159df5-7a6b-451b-bad4-0d27ab0f4f04&user_message=hi HTTP/1.1" 200 233

Description

System Info

System Information

OS: Linux OS Version: #128-Ubuntu SMP Fri Jul 5 09:28:59 UTC 2024 Python Version: 3.11.9 (main, Apr 6 2024, 17:59:24) [GCC 11.4.0]

Package Information

langchain_core: 0.3.0 langchain: 0.3.0 langchain_community: 0.3.0 langsmith: 0.1.120 langchain_chroma: 0.1.4 langchain_cohere: 0.3.0 langchain_experimental: 0.3.0 langchain_openai: 0.2.0 langchain_text_splitters: 0.3.0 langgraph: 0.2.21

Optional packages not installed

langserve

Other Dependencies

aiohttp: 3.10.5 async-timeout: Installed. No version info available. chromadb: 0.5.3 cohere: 5.9.2 dataclasses-json: 0.6.7 fastapi: 0.114.2 httpx: 0.27.2 jsonpatch: 1.33 langgraph-checkpoint: 1.0.9 numpy: 2.1.1 openai: 1.45.0 orjson: 3.10.7 packaging: 24.1 pandas: 2.2.2 pydantic: 2.9.1 pydantic-settings: 2.5.2 PyYAML: 6.0.2 requests: 2.32.3 SQLAlchemy: 2.0.34 tabulate: 0.9.0 tenacity: 9.0.0 tiktoken: 0.7.0 typing-extensions: 4.12.2

DhruvCMH commented 2 weeks ago

On further inspection I noticed that, when I run this with default gpt-4o-mini model in debug mode, I see this in console:

[-1:checkpoint] State at the end of step -1:
{'messages': []}
[0:tasks] Starting step 0 with 1 task:
- __start__ -> {'messages': [HumanMessage(content='hi')]}
[0:writes] Finished step 0 with writes to 1 channel:
- messages -> [HumanMessage(content='hi')]
[0:checkpoint] State at the end of step 0:
{'messages': [HumanMessage(content='hi', id='e2b1b41b-7a36-4f13-b186-2f14c6c2e959')]}
[1:tasks] Starting step 1 with 1 task:
- agent -> {'is_last_step': False,
 'messages': [HumanMessage(content='hi', id='e2b1b41b-7a36-4f13-b186-2f14c6c2e959')]}
[1:writes] Finished step 1 with writes to 1 channel:
- messages -> [AIMessage(content='Hello there! 🍷 How can I help you today? Are you looking for some delightful wine recommendations, perhaps a tasty recipe to pair with your favorite bottle, or maybe some local happenings around Yountville? Let’s chat!', response_metadata={'finish_reason': 'stop', 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_483d39d857'}, id='run-1e389208-ca9d-40e5-bb83-fa6bf76df051', usage_metadata={'input_tokens': 1155, 'output_tokens': 47, 'total_tokens': 1202})]
[1:checkpoint] State at the end of step 1:
{'messages': [HumanMessage(content='hi', id='e2b1b41b-7a36-4f13-b186-2f14c6c2e959'),
              AIMessage(content='Hello there! 🍷 How can I help you today? Are you looking for some delightful wine recommendations, perhaps a tasty recipe to pair with your favorite bottle, or maybe some local happenings around Yountville? Let’s chat!', response_metadata={'finish_reason': 'stop', 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_483d39d857'}, id='run-1e389208-ca9d-40e5-bb83-fa6bf76df051', usage_metadata={'input_tokens': 1155, 'output_tokens': 47, 'total_tokens': 1202})]}
[14/Sep/2024 11:09:57] "GET /stream_reply_from_chatbot/?session_id=4d40c59a-1697-446c-b07e-cdaec1c8ced8&user_message=hi HTTP/1.1" 200 583

However, when I run this with fine-tuned model, I see this:

[-1:checkpoint] State at the end of step -1:
{'messages': []}
[0:tasks] Starting step 0 with 1 task:
- __start__ -> {'messages': [HumanMessage(content='hi')]}
[0:writes] Finished step 0 with writes to 1 channel:
- messages -> [HumanMessage(content='hi')]
[0:checkpoint] State at the end of step 0:
{'messages': [HumanMessage(content='hi', id='4409af73-5303-4fec-b5e3-5d55a8bcfc9f')]}
[1:tasks] Starting step 1 with 1 task:
- agent -> {'is_last_step': False,
 'messages': [HumanMessage(content='hi', id='4409af73-5303-4fec-b5e3-5d55a8bcfc9f')]}
[14/Sep/2024 10:19:11] "GET /stream_reply_from_chatbot/?session_id=53159df5-7a6b-451b-bad4-0d27ab0f4f04&user_message=hi HTTP/1.1" 200 233

It seems like the agent is coming to step [1:tasks] Starting step 1 with 1 task: in both cases but unable to move to next step.

DhruvCMH commented 2 weeks ago

Update: I noticed that removing the trimmer from: messages_filter = custom_filter_messages | trimmer makes the model work fine.

DhruvCMH commented 2 weeks ago

It was due to the trimmer. Earlier token_counter in trimmer was defined as per gpt-4o-mini. Since now I am using the fine-tuned model, so the token counter was unavailable for the new fine-tuned model. Hence to fix this, we have to create a new LLM instance of gpt-4o-mini model, and use that as a token counter here.