langchain-ai / langchain

🦜🔗 Build context-aware reasoning applications
https://python.langchain.com
MIT License

Issue: How to handle RouterChain when one or more destination chains expect a different input variable? #6931

Closed deckikwok closed 1 year ago

deckikwok commented 1 year ago

Issue you'd like to raise.

Would like to ask how we should deal with multiple destination chains when the chains expect different input variables.

For example, in the tutorial for MultiPromptChain, I would like math questions to be directed to the PalChain instead of the standard LLMChain. With the initial LLMRouterChain, the router prompt uses input as its input variable. However, when it has decided that the input "what is 2+2" is a math question and should be routed to PalChain, I am presented with the error:

/usr/local/lib/python3.10/dist-packages/langchain/chains/base.py in _validate_inputs(self, inputs)
    101         missing_keys = set(self.input_keys).difference(inputs)
    102         if missing_keys:
--> 103             raise ValueError(f"Missing some input keys: {missing_keys}")
    104 
    105     def _validate_outputs(self, outputs: Dict[str, Any]) -> None:

ValueError: Missing some input keys: {'question'}

Manually replacing {question} with {input} in the MATH_PROMPT that PalChain uses works, but I would like to know how I can specify the input variable that the destination chain expects when setting up the destination_chains mapping here:

chain = MultiPromptChain(router_chain=router_chain, 
                         destination_chains=destination_chains, 
                         default_chain=default_chain, 
                         verbose=True
                        )

I've been at it for two nights, so I am seeking help. Thanks!

Suggestion:

No response

rjarun8 commented 1 year ago

Let's give it a try based on what I have understood.

It seems that you are facing an issue with the RouterChain when using multiple destination chains with different input variables. In your case, you want to route math questions to the PalChain instead of the standard LLMChain. However, you are encountering a ValueError due to missing input keys.

To resolve this issue, you can create a custom chain that adapts the input variables for the destination chain. This custom chain will take the input variable from the router chain and convert it to the expected input variable for the destination chain. Here's an example of how you can create such a custom chain:

from typing import Any, Dict, List
from langchain.chains.base import Chain

class InputAdapterChain(Chain):
    """Renames the router's input keys to whatever the wrapped chain expects."""
    destination_chain: Chain
    input_key_map: Dict[str, str]  # e.g. {"input": "question"}

    @property
    def input_keys(self) -> List[str]:
        return list(self.input_key_map.keys())

    @property
    def output_keys(self) -> List[str]:
        return self.destination_chain.output_keys

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        adapted_inputs = {self.input_key_map.get(k, k): v for k, v in inputs.items()}
        return self.destination_chain(adapted_inputs)

Now, you can use this InputAdapterChain to wrap the PalChain and adapt the input variable:

pal_chain = PALChain.from_math_prompt(llm)  # assuming `llm` is already initialized
input_key_map = {"input": "question"}
adapted_pal_chain = InputAdapterChain(destination_chain=pal_chain, input_key_map=input_key_map)

Finally, use the adapted_pal_chain in your MultiPromptChain setup:

chain = MultiPromptChain(router_chain=router_chain,
                         destination_chains={"math": adapted_pal_chain},
                         default_chain=default_chain,
                         verbose=True
                        )

This way, the InputAdapterChain will convert the input variable from the router chain to the expected input variable for the PalChain, avoiding the ValueError you encountered.

This may not reflect your exact implementation. Treat it as pseudocode and let me know if the approach helps.

deckikwok commented 1 year ago

@rjarun8 thanks for the prompt (no pun intended) reply. I was trying your recommendation, but MultiPromptChain's destination_chains expects LLMChains and not just any other Chain. The library checks that LLMChain should not have extra attributes and should stick with the prompt and llm attributes, which I agree we should conform to.

Let me step through the code to see what can be done.
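
Stepping through, the restriction appears to come from the field declaration on MultiPromptChain itself. Roughly (paraphrased here, not the exact library source):

from typing import Mapping

from langchain.chains import LLMChain
from langchain.chains.router.base import MultiRouteChain

class MultiPromptChain(MultiRouteChain):
    # the destination chains are typed as LLMChain, so pydantic rejects
    # (or coerces) anything that is not an LLMChain
    destination_chains: Mapping[str, LLMChain]

The parent MultiRouteChain types the same field as Mapping[str, Chain], which is what the workaround below relies on.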

ali-faiz-brainx commented 1 year ago

Hi @deckikwok, can you please provide your solution? I'm trying to use load_qa_chain with MultiPromptChain, but it throws an error about unexpected input variables.

deckikwok commented 1 year ago

@ali-faiz-brainx

I ended up implementing my own CustomMultiRouteChain extending the base MultiRouteChain. This way you can have plain Chains in destination_chains, such as the InputAdapterChain.

In the InputAdapterChain I simply remap the inputs dict in the _call() method to the data the destination chain expects.

I've shared snippets of the code below with the relevant attributes/functions:

class DKMultiPromptChain (MultiRouteChain):

    destination_chains: Mapping[str, Chain]
    """Map of name to candidate chains that inputs can be routed to. Not restricted to LLM"""

class InputConverterChain(Chain):
    destination_chain: Chain = None
    input_key_map: Dict[str, Any] = None

    def _call(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:

        for k, v in self.input_key_map.items():
            if k in inputs.keys():
                inputs[v] = inputs[k]

        print("##### inputs is now")
        print(inputs)
        # return self.destination_chain.run(adapted_inputs)

        return self.destination_chain(inputs)

ali-faiz-brainx commented 1 year ago

Thanks for helping, let me try it. Can you please share a code snippet showing how you pass extra input fields into the chain? I'm a newbie to LLMs and didn't find anything in the documentation.

ali-faiz-brainx commented 1 year ago

@deckikwok It shows me the following error: Can't instantiate abstract class InputConverterChain with abstract methods input_keys, output_keys

Here is the code I tried:

from typing import Dict, Any, Optional, Mapping
from langchain.callbacks.manager import (
    CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.router.base import MultiRouteChain

class DKMultiPromptChain (MultiRouteChain):

    destination_chains: Mapping[str, Chain]
    """Map of name to candidate chains that inputs can be routed to. Not restricted to LLM"""

class InputConverterChain(Chain):
    destination_chain: Chain = None
    input_key_map: Dict[str, Any] = None

    def _call(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:

        for k, v in self.input_key_map.items():
            if k in inputs.keys():
                inputs[v] = inputs[k]

        print("##### inputs is now")
        print(inputs)
        # return self.destination_chain.run(adapted_inputs)

        return self.destination_chain(inputs)

Here are the destination chains:

destination_chains = {}

name = prompt_infos[0]["name"]
prompt_template = prompt_infos[0]["prompt_template"]
prompt = ChatPromptTemplate.from_template(template=prompt_template)
# chain = LLMChain(llm=llm, prompt=prompt)
qa_chain = load_qa_with_sources_chain(
    llm=llm,
    chain_type="stuff",
    prompt=prompt,
    verbose=True
)
input_key_map = {"input": "question"}
chain = InputConverterChain(destination_chain=qa_chain, input_key_map=input_key_map)
destination_chains[name] = chain

name = prompt_infos[1]["name"]
prompt_template = prompt_infos[1]["prompt_template"]
prompt = ChatPromptTemplate.from_template(template=prompt_template)
chain = LLMChain(llm=llm, prompt=prompt)
destination_chains[name] = chain

deckikwok commented 1 year ago

Sorry, I've been working on other stuff and trying out LangChain on JS instead.

Try this for the InputConverterChain. Basically, you will need to implement the abstract methods of the Chain class.

from typing import Any, Dict, List, Optional
from pydantic import Extra

class InputConverterChain(Chain):
    destination_chain: Chain = None
    input_key_map: Dict[str, Any] = None

    # __fields_set__: bool = False

    class Config:
        """Configuration for this pydantic object."""
        extra = Extra.allow
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects.
        :meta private:
        """
        return ["input"]

    @property
    def output_keys(self) -> List[str]:
        """Will always return text key.
        :meta private:
        """
        return ["text"]

    def _call(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:

        for k, v in self.input_key_map.items():
            if k in inputs.keys():
                inputs[v] = inputs[k]

        print("##### inputs is now")
        print(inputs)
        # return self.destination_chain.run(adapted_inputs)

        return self.destination_chain(inputs)
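
A minimal wiring sketch of how the pieces fit together (router_chain and default_chain are assumed to be built as in the MultiPromptChain tutorial; note that the hard-coded output_keys above assumes the wrapped chain exposes a "text" output, which LLMChain does):

from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

llm = OpenAI()

# a destination whose prompt expects {question} rather than {input}
qa_prompt = PromptTemplate(template="Answer this question:\n{question}",
                           input_variables=["question"])
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)

destination_chains = {
    "qa": InputConverterChain(destination_chain=qa_chain,
                              input_key_map={"input": "question"}),
}

chain = DKMultiPromptChain(
    router_chain=router_chain,
    destination_chains=destination_chains,
    default_chain=default_chain,
    verbose=True,
)
# call the chain directly rather than run(): MultiRouteChain does not declare a single output key
result = chain({"input": "what is the capital of France?"})
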
asantos00 commented 1 year ago

I'm having a similar problem, but it seems that for some reason the RouterChain is "converting" the destination_chains from a custom chain to an LLMChain.

Print prior to execution:

destination_chains = {}
for p_info in prompt_sources:
  name = p_info["name"]
  prompt_template = p_info["prompt_template"]
  prompt = PromptTemplate(template=prompt_template, input_variables=["input"])
  dest_chain = MyCustomChain(llm=OpenAI(), prompt=prompt)
  destination_chains[name] = dest_chain
print(destination_chains)

Prints this

{'personal-trainer': MyCustomChain(memory=None, callbacks=None, callback_manager=None, verbose=False, tags=None, metadata=None, prompt=PromptTemplate(input_variables=['input'], output_parser=None, partial_variables={}, template="You're a personal trainer that's answering a question according to your profile and philosophy.\nProfile and philosophy:\n\nHuman: {input}\nPersonal trainer:", template_format='f-string', validate_template=True), llm=OpenAI(cache=None, verbose=False, callbacks=None, callback_manager=None, tags=None, metadata=None, client=<class 'openai.api_resources.completion.Completion'>, model_name='text-davinci-003', temperature=0.7, max_tokens=256, top_p=1, frequency_penalty=0, presence_penalty=0, n=1, best_of=1, model_kwargs={}, openai_api_key='', openai_api_base='', openai_organization='', openai_proxy='', batch_size=20, request_timeout=None, logit_bias={}, max_retries=6, streaming=False, allowed_special=set(), disallowed_special='all', tiktoken_model_name=None), output_key='text')}

And on the router:

destination personal-trainer <class 'langchain.chains.llm.LLMChain'> memory=None callbacks=None callback_manager=None verbose=False tags=None metadata=None prompt=PromptTemplate(input_variables=['input'], output_parser=None, partial_variables={}, template="You're a personal trainer that's answering a question according to your profile and philosophy.\nProfile and philosophy:\n\nHuman: {input}\nPersonal trainer:", template_format='f-string', validate_template=True) llm=OpenAI(cache=None, verbose=False, callbacks=None, callback_manager=None, tags=None, metadata=None, client=<class 'openai.api_resources.completion.Completion'>, model_name='text-davinci-003', temperature=0.7, max_tokens=256, top_p=1, frequency_penalty=0, presence_penalty=0, n=1, best_of=1, model_kwargs={}, openai_api_key='', openai_api_base='', openai_organization='', openai_proxy='', batch_size=20, request_timeout=None, logit_bias={}, max_retries=6, streaming=False, allowed_special=set(), disallowed_special='all', tiktoken_model_name=None) output_key='text' output_parser=StrOutputParser() return_final_only=True llm_kwargs={}
VALIDATING INPUTS ['input'] {'input': "I'm feeling quite down today and I don't feel like training for 30 minutes"}

This is strange, because when I execute it separately, the chain works just fine, calling my _call method. When it's on the router, it does take my prompt, but it seems not to use my custom chain, meaning _call never gets called.
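
In case it helps anyone hitting the same thing, a likely cause (not confirmed here) is that MultiPromptChain types destination_chains as Mapping[str, LLMChain], so pydantic re-validates each entry against LLMChain and rebuilds the custom chain as a plain LLMChain: the prompt and llm fields survive, but the subclass (and its _call) is lost. The MultiRouteChain-based workaround above keeps the field typed as Mapping[str, Chain] and avoids that. A quick way to check:

chain = MultiPromptChain(router_chain=router_chain,
                         destination_chains=destination_chains,
                         default_chain=default_chain)
# if this prints LLMChain rather than MyCustomChain, the entry was coerced
# while pydantic validated the destination_chains field
print(type(chain.destination_chains["personal-trainer"]))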