langchain-ai / langserve

LangServe 🦜️🏓
Other
1.87k stars 207 forks source link

Bug: TypeError: Type is not JSON serializable: Send #721

Open gcalabria opened 1 month ago

gcalabria commented 1 month ago

Error: TypeError: Type is not JSON serializable: Send

I am trying to replicate the LangGraph map-reduce example using LangServe and LangGraph, but I get the error `TypeError: Type is not JSON serializable: Send`.

I only get this error when serving the graph with LangServe; running the same code in a notebook works fine. So I think the problem lies with LangServe, not LangGraph.

Here is the code I am using:

graph.py

import operator
from typing import Annotated, List, TypedDict, Union

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_openai import ChatOpenAI
from langgraph.constants import Send
from langgraph.graph import END, START, StateGraph

# Prompt templates. Each one is filled in with str.format() by the node that
# uses it: {topic} / {subject} / {jokes} are the placeholders.
subjects_prompt = (
    "Generate a comma separated list of between 2 and 5 examples "
    "related to: {topic}."
)
joke_prompt = "Generate a joke about {subject}"
best_joke_prompt = (
    "Below are a bunch of jokes about {topic}. Select the best one! "
    "Return the ID of the best one.\n"
    "\n"
    "{jokes}"
)

# Structured-output schema for generate_topics (model.with_structured_output).
# Kept docstring-free on purpose: a class docstring would presumably be sent to
# the model as the schema description, changing the request.
class Subjects(BaseModel):
    subjects: list[str]  # returned by generate_topics as the "subjects" update

# Structured-output schema for generate_joke: a single joke string.
# (No docstring on purpose; see the note on the schemas above/below it in
# terms of what gets sent to the model — a docstring would presumably become
# the schema description.)
class Joke(BaseModel):
    joke: str  # wrapped in a one-element list and returned as the "jokes" update

# Structured-output schema for best_joke. No docstring on purpose: it would
# presumably be sent to the model as the schema description.
class BestJoke(BaseModel):
    id: int  # used directly as an index into state["jokes"] by best_joke

# One chat model instance, created at import time and shared by every node
# below through with_structured_output().
model = ChatOpenAI(
    model="gpt-4o-mini",
)

# Graph components: the pieces that make up the graph.

# Overall state of the main graph.
# The caller supplies `messages`. generate_topics reads the topic from the last
# message and produces `subjects`. One generate_joke run per subject (dispatched
# via Send) contributes to `jokes`, and best_joke fills `best_selected_joke`.
class OverallState(TypedDict):
    messages: List[Union[HumanMessage, AIMessage, SystemMessage]]
    # NOTE(review): not read or written by any node in this file.
    company_profile: str
    # Read by best_joke when building its prompt. It is not part of the server's
    # input type (ChatRequestType only has `messages`), so a node has to
    # populate it before best_joke runs.
    topic: str
    subjects: list
    # operator.add is this key's reducer: the one-element lists returned by
    # each generate_joke run are concatenated into a single list -- this is
    # the "reduce" part of map-reduce.
    jokes: Annotated[list, operator.add]
    best_selected_joke: str

# State handed to each generate_joke invocation: continue_to_jokes sends one
# {"subject": ...} payload per generated subject (the "map" step).
class JokeState(TypedDict):
    subject: str

# Generate the joke subjects for the topic in the user's latest message.
def generate_topics(state: OverallState):
    """Derive the topic from the last chat message and ask the model for subjects.

    Returns a partial state update containing:
      - ``subjects``: the list from the ``Subjects`` structured output.
      - ``topic``: the topic text. The graph input only carries ``messages``,
        and ``best_joke`` reads ``state["topic"]``, so it must be recorded here.

    Raises:
        ValueError: if ``messages`` is missing or empty (no topic to read).
    """
    messages = state.get("messages") or []
    if not messages:
        raise ValueError(
            "generate_topics needs at least one message to read the topic from"
        )
    topic = messages[-1].content
    print(f"📚 Topic: {topic}")
    prompt = subjects_prompt.format(topic=topic)
    response = model.with_structured_output(Subjects).invoke(prompt)
    return {"subjects": response.subjects, "topic": topic}

# "Map" step: write one joke for a single subject.
def generate_joke(state: JokeState):
    """Generate a joke for ``state["subject"]``.

    The joke is returned inside a one-element list so that the ``operator.add``
    reducer on ``OverallState.jokes`` can concatenate the results of every run.
    """
    structured = model.with_structured_output(Joke)
    result = structured.invoke(joke_prompt.format(subject=state["subject"]))
    return {"jokes": [result.joke]}

# Conditional edge that fans out ("maps") over the generated subjects.
def continue_to_jokes(state: OverallState):
    """Return one ``Send`` per subject, each routed to the ``generate_joke`` node.

    Every ``Send`` carries a ``JokeState``-shaped payload: ``{"subject": ...}``.
    """
    targets = state["subjects"]
    return [Send("generate_joke", {"subject": name}) for name in targets]

# "Reduce" step: ask the model to pick the best of the collected jokes.
def best_joke(state: OverallState):
    """Select the best joke from ``state["jokes"]``.

    Each joke is listed in the prompt with its zero-based index as its ID. That
    ID is what the prompt asks the model to return, and what ``BestJoke.id`` is
    used as.

    Returns:
        ``{"best_selected_joke": <the chosen joke text>}``.

    Raises:
        ValueError: if the model returns an ID outside ``range(len(jokes))``.
            Without this check, a negative ID would silently index from the end.
    """
    jokes = state["jokes"]
    # Use `topic` if a node has set it. Otherwise derive it the same way
    # generate_topics does: from the latest message.
    topic = state.get("topic") or state["messages"][-1].content
    # Label each joke with its index so the model has an ID to return.
    numbered = "\n\n".join(f"{idx}: {joke}" for idx, joke in enumerate(jokes))
    prompt = best_joke_prompt.format(topic=topic, jokes=numbered)
    response = model.with_structured_output(BestJoke).invoke(prompt)
    if not 0 <= response.id < len(jokes):
        raise ValueError(
            f"Model returned joke ID {response.id}, but only {len(jokes)} "
            "joke(s) are available"
        )
    return {"best_selected_joke": jokes[response.id]}

def create_graph():
    """Build and compile the map-reduce joke graph.

    Flow: START -> generate_topics -> one generate_joke per subject (fanned out
    by the continue_to_jokes conditional edge) -> best_joke -> END.
    """
    builder = StateGraph(OverallState)
    for node_name, node_fn in (
        ("generate_topics", generate_topics),
        ("generate_joke", generate_joke),
        ("best_joke", best_joke),
    ):
        builder.add_node(node_name, node_fn)
    builder.add_edge(START, "generate_topics")
    # The third argument lists the possible destinations of continue_to_jokes.
    builder.add_conditional_edges(
        "generate_topics", continue_to_jokes, ["generate_joke"]
    )
    builder.add_edge("generate_joke", "best_joke")
    builder.add_edge("best_joke", END)
    return builder.compile()

server.py

from typing import List, Union

from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.pydantic_v1 import BaseModel
from langserve import add_routes

# Load environment variables from the .env file *before* importing graph.py:
# graph.py constructs ChatOpenAI at module import time, which presumably reads
# its API key from the environment.
load_dotenv()

from graph import create_graph  # noqa: E402  (must follow load_dotenv)


class ChatRequestType(BaseModel):
    """Input schema for the /graph route: the chat history sent by the client."""

    messages: List[Union[HumanMessage, AIMessage, SystemMessage]]


app = FastAPI(
    title="Example Bot API",
    version="1.0",
    description="Backend for the Example Bot",
)

# Configure CORS for the local frontend origins.
origins = [
    "http://localhost",
    "http://localhost:3000",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

graph = create_graph()
runnable = graph.with_types(
    input_type=ChatRequestType, output_type=dict
)
# NOTE(review): streaming events from this route has been observed to fail with
# "TypeError: Type is not JSON serializable: Send" inside langserve's serializer
# (langgraph Send objects appear in the event stream). A serializer workaround
# is needed until langserve handles these objects.
add_routes(
    app,
    runnable,
    path="/graph",
    playground_type="chat",
    include_callback_events=True,
)

Here is the error message:

  +-+---------------- 1 ----------------
    | Traceback (most recent call last):
    |   File "/Users/gui/.pyenv/versions/3.12.3/envs/advanced-rag-experiments/lib/python3.12/site-packages/langserve/serialization.py", line 90, in default
    |     return super().default(obj)
    |            ^^^^^^^
    | RuntimeError: super(): __class__ cell not found
    | 
    | The above exception was the direct cause of the following exception:
    | 
    | Traceback (most recent call last):
    |   File "/Users/gui/.pyenv/versions/3.12.3/envs/advanced-rag-experiments/lib/python3.12/site-packages/sse_starlette/sse.py", line 273, in wrap
    |     await func()
    |   File "/Users/gui/.pyenv/versions/3.12.3/envs/advanced-rag-experiments/lib/python3.12/site-packages/sse_starlette/sse.py", line 253, in stream_response
    |     async for data in self.body_iterator:
    |   File "/Users/gui/.pyenv/versions/3.12.3/envs/advanced-rag-experiments/lib/python3.12/site-packages/langserve/api_handler.py", line 1352, in _stream_events
    |     "data": self._serializer.dumps(event).decode("utf-8"),
    |             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |   File "/Users/gui/.pyenv/versions/3.12.3/envs/advanced-rag-experiments/lib/python3.12/site-packages/langserve/serialization.py", line 168, in dumps
    |     return orjson.dumps(obj, default=default)
    |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    | TypeError: Type is not JSON serializable: Send
    +------------------------------------
gcalabria commented 1 month ago

I've opened a PR to fix this.

deeplathiya commented 1 month ago

Is there any way to solve this without changing the code in the Python packages? For example, by changing something on our side instead of in the packages?

eyurtsev commented 1 month ago

Yes, serialization issues should be solvable by appending a RunnableGenerator or RunnableLambda that converts the object to JSON-serializable data.

At the moment, langserve is not considered to be compatible with langgraph -- https://github.com/langchain-ai/langserve?tab=readme-ov-file#%EF%B8%8F-langgraph-compatibility

The APIHandler itself has a serializer property that isn't exposed; one solution would be to expose it so users can pass custom serialization/deserialization (serde) hooks.

dsculptor commented 1 month ago

A workaround that works is to monkey-patch langserve's `default` serialization function, much like your PR, @gcalabria:

from typing import Any
import langserve.serialization
from langgraph.constants import Send
from pydantic.v1 import BaseModel

def custom_default(obj) -> Any:
    """Custom serialization for well known objects that orjson cannot encode.

    Used as orjson's ``default`` hook. It is called for each object orjson does
    not know how to serialize, and must return a serializable value or raise
    ``TypeError``.

    Handles:
      - pydantic (v1) ``BaseModel`` instances -> ``obj.dict()``
      - langgraph ``Send`` -> ``{"node": obj.node, "arg": obj.arg}``
      - objects that define a ``__dict__()`` *method* -> its return value

    Raises:
        TypeError: for any other object, as orjson expects.
    """
    if isinstance(obj, BaseModel):
        return obj.dict()

    if isinstance(obj, Send):
        return {"node": obj.node, "arg": obj.arg}

    # Honor an explicit ``__dict__()`` method. On ordinary objects ``__dict__``
    # is the attribute dict itself (not callable), so check before calling.
    # The previous bare ``except:`` swallowed every error, including
    # KeyboardInterrupt and bugs inside a real ``__dict__()`` method.
    to_dict = getattr(obj, "__dict__", None)
    if callable(to_dict):
        return to_dict()

    raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")

# Replace langserve's module-level serialization hook. langserve's `dumps`
# passes `default` to orjson by its global name (orjson.dumps(obj,
# default=default)), so the replacement is picked up on the next call.
setattr(langserve.serialization, "default", custom_default)