langchain-ai / langchain

🦜🔗 Build context-aware reasoning applications
https://python.langchain.com
MIT License

RunnableParallel input schema is empty if child runnables' input schemas use TypedDicts #24326

Open mobiware opened 1 month ago

mobiware commented 1 month ago

Checked other resources

Example Code

from typing import TypedDict
from langchain_core.runnables import RunnableParallel, RunnableLambda

class Foo(TypedDict):
    foo: str

class InputData(Foo):
    bar: str

def forward_foo(input_data: InputData):
    return input_data["foo"]

def transform_input(input_data: InputData):
    foo = input_data["foo"]
    bar = input_data["bar"]

    return {
        "transformed": foo + bar
    }

foo_runnable = RunnableLambda(forward_foo)
other_runnable = RunnableLambda(transform_input)

parallel = RunnableParallel(
    foo=foo_runnable,
    other=other_runnable,
)

repr(parallel.input_schema.validate({"foo": "Y", "bar": "Z"}))
# Actual result: 'RunnableParallel<foo,other>Input()' (both keys are dropped)

# If we remove the type annotations from the input_data argument of
# forward_foo and transform_input, validate() gives the expected result:
# "RunnableParallel<foo,other>Input(foo='Y', bar='Z')"

Error Message and Stack Trace (if applicable)

No response

Description

When TypedDict subclasses are used to annotate the arguments of a RunnableParallel's child runnables, the RunnableParallel input schema isn't correctly inferred from the children's schemas.

The RunnableParallel input schema is empty: parallel.input_schema.schema() outputs:

{'title': 'RunnableParallel<foo,other>Input',
 'type': 'object',
 'properties': {}}

and parallel.input_schema.validate() returns an instance with no fields (an empty dict) for any input.

This is a problem when exposing the RunnableParallel chain with Langserve: Langserve passes the endpoint input through schema.validate(), which effectively clears the input because it returns an empty dict.
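Why an empty schema clears the input can be modeled in a few lines of plain Python. This sketch assumes validation that keeps only declared fields and silently drops any other key, which is the default behaviour of pydantic v1 models (`extra = Extra.ignore`); `validate_declared` is a hypothetical stand-in, not the pydantic or Langserve code path.

```python
from typing import Any, Dict, Iterable


def validate_declared(fields: Iterable[str], data: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical model of 'ignore extra keys' validation: keep only the
    declared fields that are present in the input, drop everything else."""
    declared = set(fields)
    return {key: value for key, value in data.items() if key in declared}


payload = {"foo": "Y", "bar": "Z"}

# Schema inferred with both fields (as when the children are unannotated).
print(validate_declared(["foo", "bar"], payload))  # {'foo': 'Y', 'bar': 'Z'}

# Empty schema, as reported for the TypedDict-annotated children.
print(validate_declared([], payload))  # {}
```

Under this model no error is raised; the request simply reaches the chain with every key stripped, which matches the symptom described above.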

The only workarounds we have found so far are either:

System Info

System Information

OS: Darwin
OS Version: Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000
Python Version: 3.10.13 (main, Aug 24 2023, 12:59:26) [Clang 15.0.0 (clang-1500.1.0.2.5)]

Package Information

langchain_core: 0.2.20
langchain: 0.2.8
langchain_community: 0.2.7
langsmith: 0.1.88
langchain_anthropic: 0.1.20
langchain_cli: 0.0.25
langchain_openai: 0.1.16
langchain_text_splitters: 0.2.2
langchainhub: 0.1.20
langserve: 0.2.2

Packages not installed (Not Necessarily a Problem)

The following packages were not found:

langgraph

mobiware commented 1 month ago

Another workaround is to use with_types(input_type=InputData):

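# Workaround: declare the parallel's input type explicitly
# instead of inferring it from the children's schemas
parallel = RunnableParallel(
    foo=foo_runnable,
    other=other_runnable,
).with_types(input_type=InputData)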