jackmpcollins / magentic

Seamlessly integrate LLMs as Python functions
https://magentic.dev/
MIT License
1.93k stars 95 forks source link

LangGraph Compatibility #287

Closed mjrusso closed 1 month ago

mjrusso commented 1 month ago

Here's a (contrived) demo of how to use Magentic with LangGraph, using Magentic's not-yet-advertised Chat class:

%%capture --no-stderr
%pip install -U magentic langgraph
from magentic.chat_model.message import Message
from pydantic import BaseModel, Field
from typing import List, TypedDict

class CustomerNumber(BaseModel):
    """The user's Acme Corporation customer number."""

    # Ideally, we'd provide other validation here.
    # Kept as a free-form string; the docstring/description guide the LLM's structured output.
    customer_number: str = Field(description="The user's customer number")

class PhoneNumber(BaseModel):
    """The user's phone number."""

    # Ideally, we'd provide other validation here.
    # Free-form string; no format/region validation is applied yet.
    phone_number: str = Field(description="The user's phone number")

class State(TypedDict):
    """LangGraph state shared by all nodes in the workflow."""

    # Conversation history (magentic Message objects); appended to by each node.
    messages: List[Message]
    # Populated once the LLM extracts the corresponding structured value.
    customer_number: CustomerNumber
    phone_number: PhoneNumber
from magentic import AssistantMessage, SystemMessage, UserMessage, OpenaiChatModel
from magentic.chat import Chat

def solicit_user_information(state: State):
    """Run one chat turn asking the user for their customer or phone number.

    Returns a partial state update: the extended message history, plus
    ``customer_number`` or ``phone_number`` when the model returned one of the
    structured output types instead of plain text.
    """

    chat = Chat(
        messages = [
            SystemMessage(
                """You are an assistant for Acme Corporation. You will provide customer support,
                   but only once the user provides their Acme customer number or their phone number.
                   Solicit the customer number or phone number from the user (if one has not already been provided)."""
            ),
            *state["messages"]
        ],
        output_types=[str, CustomerNumber, PhoneNumber], model=OpenaiChatModel("gpt-4o")
    ).submit()

    # Either a str (still soliciting) or one of the structured output types.
    response = chat.last_message.content

    messages = [*state["messages"], chat.last_message]

    # isinstance() is the idiomatic type check; type(x) == T breaks for subclasses.
    if isinstance(response, CustomerNumber):
        return {"customer_number": response, "messages": messages}
    if isinstance(response, PhoneNumber):
        return {"phone_number": response, "messages": messages}
    return {"messages": messages}

def provide_support(state: State):
    """
    Provide customer support
    """
    # Placeholder node: real support logic is not implemented in this demo.
    print("TODO!")
    return state

def can_provide_support(state: State) -> bool:
    """Return True once the user has supplied a customer number or phone number.

    Used as a conditional-edge predicate: routes to support when either
    identifying field has been populated in the graph state.
    """
    # bool(...) collapses the if/else-return-True/False pattern; .get() tolerates
    # keys that have not been written to the state yet.
    return bool(state.get("customer_number") or state.get("phone_number"))
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.checkpoint.memory import MemorySaver

import pickle

workflow = StateGraph(State)

# Two nodes share the same callable: the "interrupt_" node exists only so the
# graph can pause (interrupt_before) and wait for more user input.
workflow.add_node("solicit_user_information", solicit_user_information)
workflow.add_node("interrupt_solicit_user_information", solicit_user_information)
workflow.add_node("provide_support", provide_support)

workflow.add_edge(START, "solicit_user_information")
workflow.add_conditional_edges(
    "solicit_user_information",
    can_provide_support,
    {True: "provide_support", False: "interrupt_solicit_user_information"},
)
workflow.add_edge("interrupt_solicit_user_information", "solicit_user_information")
workflow.add_edge("provide_support", END)

# Workaround: pickle-based serde, because magentic Message objects are not
# JSON serializable (SqliteSaver crashes — see discussion below).
checkpointer = MemorySaver(serde=pickle)

graph = workflow.compile(
    checkpointer=checkpointer,
    interrupt_before=["interrupt_solicit_user_information"],
)

from IPython.display import Image, display

try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception as e:
    print(e)

jpeg

config = {"configurable": {"thread_id": "1"}}

user_input = "Hi there! My name is Michael. What's up?"

events = graph.stream(
    {"messages": [UserMessage(user_input)],
     "customer_number": None,
     "phone_number": None
    },
    config, 
    stream_mode="values"
)

print("-------")

for event in events:
    print(event["messages"][-1].format())
    print("=======")
-------
UserMessage("Hi there! My name is Michael. What's up?")
=======
AssistantMessage("Hi Michael! I'm here to help you with any questions or issues you may have. Could you please provide your Acme customer number or phone number so I can assist you better?")
=======
user_input = "My customer number is 838384848484872"

state = graph.get_state(config)

state.values["messages"] += [UserMessage(content=user_input)]

(next,) = state.next

graph.update_state(config, state.values, as_node=next)

events = graph.stream(
    None,
    config, 
    stream_mode="values"
)

print("-------")

for event in events:
    print(event["messages"][-1].format())
    print("=======")
-------
AssistantMessage(CustomerNumber(customer_number='838384848484872'))
=======
TODO!
AssistantMessage(CustomerNumber(customer_number='838384848484872'))
=======

A few notes:

The one issue I've run into so far is that the internal Magentic message classes are not serializable. If you swap out this line:

checkpointer = MemorySaver(serde=pickle)

For:

checkpointer = SqliteSaver.from_conn_string(":memory:")

Then it crashes at runtime with:

TypeError: Object of type `UserMessage` is not JSON serializable

(Because we're serializing List[Message] in the graph state.)

Conceivably it might also be nice to be able to serialize entire Chat instances in the graph state too.

I'm going deeper with LangGraph and will post any more findings in this thread.

jwd-dev commented 1 month ago

Any progress here? Had the same thoughts.

jackmpcollins commented 1 month ago

Hi @mjrusso @jwd-dev to confirm, the only issue here is that Message objects are not serializable? Would this work if Message was a pydantic model? I just created issue https://github.com/jackmpcollins/magentic/issues/289 for this. If not, could you share what is required to make these serializable please.

mjrusso commented 1 month ago

Thanks @jackmpcollins 👍

Correct, the only rough edge I've come across so far using Magentic with LangGraph is that Message objects are not natively serializable.

LangGraph supports using Pydantic models for state (example notebook: https://langchain-ai.github.io/langgraph/how-tos/state-model/). This particular example is using the Pydantic model as the "top level" state container, which is similar but not exactly the same thing. My hunch is that making Message a Pydantic model would work though.

The relevant serde code for LangGraph is here: https://github.com/langchain-ai/langgraph/tree/main/libs/checkpoint

Curiously, I don't see any Pydantic-specific handling here (I must be missing something). There are tests in that folder that do test serializing and deserializing Pydantic models.

Here's the underlying LangGraph SerializerProtocol protocol definition: https://github.com/langchain-ai/langgraph/blob/487157eafa3190dc5b02126aa6a113d08c42acbe/libs/checkpoint/langgraph/checkpoint/serde/base.py#L4-L25

mjrusso commented 1 month ago

I just took a quick look at what LangChain is doing. LangChain's BaseMessage is serializable with LangGraph; it inherits from Serializable:

Based on this, I took a closer look at LangGraph's serde implementation. Here's where the support for LangChain's serializable comes from:

https://github.com/langchain-ai/langgraph/blob/487157eafa3190dc5b02126aa6a113d08c42acbe/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py#L37-L38

And, next line down is where Pydantic serialization support comes from:

https://github.com/langchain-ai/langgraph/blob/487157eafa3190dc5b02126aa6a113d08c42acbe/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py#L39-L40

mjrusso commented 1 month ago

If Magentic messages are Pydantic models, and you use a Pydantic model as the top-level state container, then I'm sure that everything will work properly. If instead you were to use a TypedDict as the top-level state container, I'm not entirely sure that would work. I'll rig up a test.

mjrusso commented 1 month ago

Just did some tests, and can confirm that making Message a Pydantic model will work in all cases.


For posterity:

Example using Pydantic models with TypedDict as top-level state container:

%%capture --no-stderr
%pip install -U langgraph
from pydantic import BaseModel, Field
from typing import List, TypedDict

class CustomerNumber(BaseModel):
    """The user's Acme Corporation customer number."""

    customer_number: str = Field(description="The user's customer number")

class PhoneNumber(BaseModel):
    """The user's phone number."""

    phone_number: str = Field(description="The user's phone number")
class State(TypedDict):
    customer_numbers: List[CustomerNumber]
    phone_number: PhoneNumber
def hello_world(state: State):
    """Return a state update populated with example customer and phone numbers."""
    print("hello, world!")

    numbers = [
        CustomerNumber(customer_number="12312565656565656565"),
        CustomerNumber(customer_number="999999999"),
    ]
    phone = PhoneNumber(phone_number="444-111-9999")

    return {"customer_numbers": numbers, "phone_number": phone}
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.sqlite import SqliteSaver

workflow = StateGraph(State)

workflow.add_node("hello_world", hello_world)
workflow.add_edge(START, "hello_world")
workflow.add_edge("hello_world", END)

checkpointer = SqliteSaver.from_conn_string(":memory:")

graph = workflow.compile(checkpointer=checkpointer)
config = {"configurable": {"thread_id": "1"}}

events = graph.stream(
    {"customer_numbers": None, "phone_number": None},
    config, 
    stream_mode="values"
)

print("-------")

for event in events:
    print(event)
    print("=======")
-------
{'customer_numbers': None, 'phone_number': None}
=======
hello, world!
{'customer_numbers': [CustomerNumber(customer_number='12312565656565656565'), CustomerNumber(customer_number='999999999')], 'phone_number': PhoneNumber(phone_number='444-111-9999')}
=======
graph.get_state(config)
StateSnapshot(values={'customer_numbers': [CustomerNumber(customer_number='12312565656565656565'), CustomerNumber(customer_number='999999999')], 'phone_number': PhoneNumber(phone_number='444-111-9999')}, next=(), config={'configurable': {'thread_id': '1', 'thread_ts': '1ef50dbc-0d38-6618-8001-a4fa28356614'}}, metadata={'source': 'loop', 'writes': {'hello_world': {'customer_numbers': [CustomerNumber(customer_number='12312565656565656565'), CustomerNumber(customer_number='999999999')], 'phone_number': PhoneNumber(phone_number='444-111-9999')}}, 'step': 1}, created_at='2024-08-02T14:30:23.492245+00:00', parent_config={'configurable': {'thread_id': '1', 'thread_ts': '1ef50dbc-0d30-658a-8000-f12c9ee64323'}})

Example using Pydantic models with Pydantic BaseModel as top-level state container:

%%capture --no-stderr
%pip install -U langgraph
from pydantic import BaseModel, Field
from typing import List, Optional

class CustomerNumber(BaseModel):
    """The user's Acme Corporation customer number."""

    customer_number: str = Field(description="The user's customer number")

class PhoneNumber(BaseModel):
    """The user's phone number."""

    phone_number: str = Field(description="The user's phone number")
class State(BaseModel):
    customer_numbers: Optional[List[CustomerNumber]]
    phone_number: Optional[PhoneNumber]
def hello_world(state: State):
    """Fill in example customer and phone numbers on the Pydantic state model."""
    print("hello, world!")

    state.customer_numbers = [
        CustomerNumber(customer_number=raw)
        for raw in ("12312565656565656565", "999999999")
    ]
    state.phone_number = PhoneNumber(phone_number="444-111-9999")

    return dict(state)
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.sqlite import SqliteSaver

workflow = StateGraph(State)

workflow.add_node("hello_world", hello_world)
workflow.add_edge(START, "hello_world")
workflow.add_edge("hello_world", END)

checkpointer = SqliteSaver.from_conn_string(":memory:")

graph = workflow.compile(checkpointer=checkpointer)
config = {"configurable": {"thread_id": "1"}}

events = graph.stream(
    {"customer_numbers": None, "phone_number": None},
    config, 
    stream_mode="values"
)

print("-------")

for event in events:
    print(event)
    print("=======")
-------
{'customer_numbers': None, 'phone_number': None}
=======
hello, world!
{'customer_numbers': [CustomerNumber(customer_number='12312565656565656565'), CustomerNumber(customer_number='999999999')], 'phone_number': PhoneNumber(phone_number='444-111-9999')}
=======
graph.get_state(config)
StateSnapshot(values={'customer_numbers': [CustomerNumber(customer_number='12312565656565656565'), CustomerNumber(customer_number='999999999')], 'phone_number': PhoneNumber(phone_number='444-111-9999')}, next=(), config={'configurable': {'thread_id': '1', 'thread_ts': '1ef50dbc-a4cc-6d3a-8001-e4e32dc68450'}}, metadata={'source': 'loop', 'writes': {'hello_world': {'customer_numbers': [CustomerNumber(customer_number='12312565656565656565'), CustomerNumber(customer_number='999999999')], 'phone_number': PhoneNumber(phone_number='444-111-9999')}}, 'step': 1}, created_at='2024-08-02T14:30:39.386542+00:00', parent_config={'configurable': {'thread_id': '1', 'thread_ts': '1ef50dbc-a4c3-6834-8000-b0052c941eb0'}})
jackmpcollins commented 1 month ago

@mjrusso @jwd-dev I've opened PR https://github.com/jackmpcollins/magentic/pull/294 that makes Message a BaseModel. Please test it out and let me know if that works for you!

mjrusso commented 1 month ago

@jackmpcollins awesome, thank you! Just gave it a spin and it looks like the implementation in #294 is working.

Minimal working example:

%%capture --no-stderr
%pip install -U langgraph
from pydantic import BaseModel
from typing import List
from magentic import AnyMessage

class State(BaseModel):
    messages: List[AnyMessage]
from magentic import SystemMessage, AssistantMessage, UserMessage

def hello_world(state: State):
    """Append a short example conversation to the message history in state."""
    print("hello, world!")

    exchange = [
        UserMessage("What day is it today?"),
        AssistantMessage("As a large language model, I do not have access to the date."),
        UserMessage("No problem, thanks for trying."),
    ]
    state.messages += exchange

    return dict(state)
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.sqlite import SqliteSaver

workflow = StateGraph(State)

workflow.add_node("hello_world", hello_world)
workflow.add_edge(START, "hello_world")
workflow.add_edge("hello_world", END)

checkpointer = SqliteSaver.from_conn_string(":memory:")

graph = workflow.compile(checkpointer=checkpointer)
config = {"configurable": {"thread_id": "1"}}

graph.invoke(
    {"messages": [SystemMessage("You are a somewhat helpful assistant.")]},
    config,
    stream_mode="values"
)
hello, world!

{'messages': [SystemMessage('You are a somewhat helpful assistant.'),
  UserMessage('What day is it today?'),
  AssistantMessage('As a large language model, I do not have access to the date.'),
  UserMessage('No problem, thanks for trying.')]}
graph.get_state(config)
StateSnapshot(values={'messages': [SystemMessage('You are a somewhat helpful assistant.'), UserMessage('What day is it today?'), AssistantMessage('As a large language model, I do not have access to the date.'), UserMessage('No problem, thanks for trying.')]}, next=(), config={'configurable': {'thread_id': '1', 'thread_ts': '1ef5419a-3e10-6530-8001-bf3530b5fe14'}}, metadata={'source': 'loop', 'writes': {'hello_world': {'messages': [SystemMessage('You are a somewhat helpful assistant.'), UserMessage('What day is it today?'), AssistantMessage('As a large language model, I do not have access to the date.'), UserMessage('No problem, thanks for trying.')]}}, 'step': 1}, created_at='2024-08-06T17:30:57.212332+00:00', parent_config={'configurable': {'thread_id': '1', 'thread_ts': '1ef5419a-3e03-6308-8000-4e9628157690'}})
jackmpcollins commented 1 month ago

@mjrusso Thanks for confirming! Released https://github.com/jackmpcollins/magentic/releases/tag/v0.29.0 now which includes this. Let me know if you run into any other issues

mjrusso commented 1 month ago

Beauty, thanks @jackmpcollins! 👏 I'll close this ticket for now; will let you know if I run into any other LangGraph-related issues.