langchain-ai / langchain

🦜🔗 Build context-aware reasoning applications
https://python.langchain.com
MIT License

Issue: How to validate Tool input arguments without raising ValidationError #13662

Open deepakagrawal opened 7 months ago

deepakagrawal commented 7 months ago

Issue you'd like to raise.

I am creating a tool that takes multiple input arguments (say input1, input2). I would like to validate their types and also make sure that the tool only receives input1 and input2. How do I validate this without breaking the LLM chain? I would instead like to return a warning to the LLM agent, something like "The inputs passed were incorrect, please try again".

Suggestion:

No response

dosubot[bot] commented 7 months ago

🤖

Hello @deepakagrawal,

Great to see you again! I remember our previous interaction and I'm glad to assist you once more.

I'm currently looking into the issue you've raised about validating input arguments without breaking the llm chain. I'll provide a detailed response shortly.

dosubot[bot] commented 7 months ago

🤖

Hi @deepakagrawal! Good to see you again on the LangChain repository. Hope you're doing well.

To validate the input arguments of a tool in the LangChain framework without raising a ValidationError, you can add a root validator (validate_inputs in the example below) to your input schema class, which should be a subclass of Pydantic's BaseModel. Instead of raising an error, the validator adds a warning message to the values dictionary. Here's an example:

from pydantic import BaseModel, Field, root_validator
from typing import Dict, Any

class ToolInputSchema(BaseModel):
    input1: str = Field(...)
    input2: str = Field(...)

    @root_validator
    def validate_inputs(cls, values: Dict[str, Any]) -> Dict:
        input1 = values.get("input1")
        input2 = values.get("input2")
        # Add your validation logic here
        if not isinstance(input1, str) or not isinstance(input2, str):
            values["warning"] = "The inputs passed were incorrect, please try again"
        return values

In this example, if input1 or input2 is not a string, a warning message is added to the values dictionary under the key "warning". This warning message can then be checked and handled by the LLM agent.

This approach is based on the usage of Pydantic's root_validator decorator, which allows you to add custom validation logic to your models. You can find more information about it in the Pydantic documentation.

Please note that this is a general approach and you might need to adjust it according to your specific needs and the structure of your tool. For example, you might need to adjust the validation logic or the structure of the warning message.

Also, remember to set your ToolInputSchema as the args_schema attribute of your tool class, as shown in the BaseTool class in the LangChain codebase:

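from typing import Type
from langchain.tools import BaseTool

class MyTool(BaseTool):
    args_schema: Type[BaseModel] = ToolInputSchema  # annotate the override so Pydantic treats it as the field
    ...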

This way, the _parse_input method of the BaseTool class will use your ToolInputSchema to validate and parse the tool's input.

Sources


This response is meant to be useful and save you time. It is not meant to be a precise solution, but rather a starting point for your own research.

Help me be more useful! Please leave a 👍 if this is helpful and 👎 if it is irrelevant.

If you want to continue the conversation, start your reply with @dosu-bot.

deepakagrawal commented 7 months ago

@dosu-bot, can you show an example of how this "warning" is used by the LLM agent?

dosubot[bot] commented 7 months ago

🤖

Based on the current implementation of the LangChain framework, the "warning" field you added to the values dictionary in your Pydantic model's root_validator will not be used by the llm agent.

The _parse_input method of the BaseTool class, which is responsible for validating the tool's input arguments, only returns the fields that were present in the original input. This is done by the following line of code:

return {k: v for k, v in result.dict().items() if k in tool_input}

This means that any additional fields added to the Pydantic model during validation, such as your "warning" field, will not be included in the returned dictionary.
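To see why the extra key disappears, here is the same dictionary comprehension applied to plain dictionaries. This is a stdlib-only illustration: filter_to_tool_input is a hypothetical helper name, and the validated dictionary is a stand-in for result.dict() after a root validator added a "warning" key.

```python
from typing import Any, Dict


def filter_to_tool_input(validated: Dict[str, Any], tool_input: Dict[str, Any]) -> Dict[str, Any]:
    # Same filter as the line quoted above: keep only keys the caller passed in
    return {k: v for k, v in validated.items() if k in tool_input}


tool_input = {"input1": "a", "input2": "b"}
# Stand-in for result.dict(): the validator added a "warning" key the caller never sent
validated = {"input1": "a", "input2": "b", "warning": "The inputs passed were incorrect"}

parsed = filter_to_tool_input(validated, tool_input)
print(parsed)  # {'input1': 'a', 'input2': 'b'}
```

Because "warning" was never a key of tool_input, it is filtered out before the tool or the agent can see it.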

If you want to include a warning message in the tool's output, you might need to modify the tool's run method to check for the presence of the "warning" field in the parsed input and include it in the output. However, this would require modifying the LangChain framework's code, which might not be desirable or feasible depending on your situation.

As an alternative, you could raise a ValueError in your root_validator when the inputs are not valid. Pydantic wraps it in a ValidationError, which causes the _parse_input method to raise an exception that could be caught and handled by the llm agent. The exception's message would include your warning message. Here's how you could modify your root_validator to do this:

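from typing import Any, Dict

from pydantic import BaseModel, root_validator  # Pydantic v1 API


class ToolInputSchema(BaseModel):
    input1: str
    input2: str

    @root_validator
    def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        input1 = values.get("input1")
        input2 = values.get("input2")
        if not isinstance(input1, str) or not isinstance(input2, str):
            # Validators raise ValueError; Pydantic wraps it in a ValidationError
            raise ValueError("The inputs passed were incorrect, please try again")
        return values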

I hope this helps! If you have any further questions, feel free to ask.

deepakagrawal commented 7 months ago

@dosu-bot this breaks the LangChain flow because of the raised ValidationError.

dosubot[bot] commented 7 months ago

🤖

To validate the input arguments in the LangChain framework and return a warning to the llm agent without breaking the LangChain flow, you can modify the validate_inputs method in the SmartLLMChain class. Instead of raising a ValueError when the inputs are not valid, you can modify the method to add a "warning" field to the values dictionary that will be returned. This "warning" field can contain a string that describes the validation error.

Here is the modified validate_inputs method:

@root_validator  # type: ignore[call-overload]
@classmethod
def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Ensure we have an LLM for each step."""
    llm = values.get("llm")
    ideation_llm = values.get("ideation_llm")
    critique_llm = values.get("critique_llm")
    resolver_llm = values.get("resolver_llm")
    warning = None
    if not llm and not ideation_llm:
        warning = (
            "Either ideation_llm or llm needs to be given. Pass llm, "
            "if you want to use the same llm for all steps, or pass "
            "ideation_llm, critique_llm and resolver_llm if you want "
            "to use different llms for each step."
        )
    elif not llm and not critique_llm:
        warning = (
            "Either critique_llm or llm needs to be given. Pass llm, "
            "if you want to use the same llm for all steps, or pass "
            "ideation_llm, critique_llm and resolver_llm if you want "
            "to use different llms for each step."
        )
    elif not llm and not resolver_llm:
        warning = (
            "Either resolver_llm or llm needs to be given. Pass llm, "
            "if you want to use the same llm for all steps, or pass "
            "ideation_llm, critique_llm and resolver_llm if you want "
            "to use different llms for each step."
        )
    elif llm and ideation_llm and critique_llm and resolver_llm:
        warning = (
            "LLMs are given for each step (ideation_llm, critique_llm,"
            " resolver_llm), but backup LLM (llm) is also given, which"
            " would not be used."
        )
    if warning:
        values["warning"] = warning
    return values

This way, the llm agent can check if the "warning" field is present in the returned dictionary and handle the warning accordingly. You can find this method in the libs/experimental/langchain_experimental/smart_llm/base.py file.

As for handling warnings and errors during the execution of a tool, the LangChain framework uses a combination of exception handling and retry policies. For instance, a function that fails some fraction of the time by raising a ValueError can be wrapped in a RunnableLambda and given a retry policy with the with_retry method, which retries it up to a specified number of times upon failure.
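The retry idea can be sketched without LangChain as a small wrapper. This is a plain-Python analogue, not the with_retry implementation: the parameter names stop_after_attempt and retry_if_exception_type are borrowed from with_retry, while retry and flaky are hypothetical names, and unlike with_retry this sketch does not wait between attempts.

```python
from typing import Callable, Dict, Tuple, Type, TypeVar

T = TypeVar("T")


def retry(
    func: Callable[[], T],
    stop_after_attempt: int = 3,
    retry_if_exception_type: Tuple[Type[BaseException], ...] = (Exception,),
) -> T:
    """Call func until it succeeds or stop_after_attempt attempts have failed."""
    if stop_after_attempt < 1:
        raise ValueError("stop_after_attempt must be at least 1")
    for attempt in range(1, stop_after_attempt + 1):
        try:
            return func()
        except retry_if_exception_type:
            if attempt == stop_after_attempt:
                raise  # out of attempts: re-raise the last error
    raise AssertionError("unreachable")


calls: Dict[str, int] = {"n": 0}


def flaky() -> str:
    """Raises ValueError on the first two calls and succeeds on the third."""
    calls["n"] += 1
    if calls["n"] < 3:
        raise ValueError("transient failure")
    return "ok"


print(retry(flaky, stop_after_attempt=3))  # ok
```

Note that retrying only helps with transient failures; an input that fails validation will fail the same way on every attempt.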

In addition, the framework provides a return_exceptions parameter in the batch method. If this parameter is set to True, the method will return exceptions instead of raising them. This allows you to handle exceptions in a way that suits your needs.
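The effect of return_exceptions=True can be illustrated with a plain-Python batch helper: failures come back as exception objects in the result list instead of aborting the whole batch. This batch function is a hypothetical, sequential analogue, not LangChain's batch implementation.

```python
from typing import Callable, List, TypeVar, Union

InT = TypeVar("InT")
OutT = TypeVar("OutT")


def batch(
    func: Callable[[InT], OutT],
    inputs: List[InT],
    return_exceptions: bool = False,
) -> List[Union[OutT, Exception]]:
    """Apply func to each input in order, optionally collecting exceptions."""
    results: List[Union[OutT, Exception]] = []
    for item in inputs:
        try:
            results.append(func(item))
        except Exception as e:
            if not return_exceptions:
                raise  # default: the first failure aborts the whole batch
            results.append(e)  # keep the exception in place of the missing output
    return results


results = batch(int, ["1", "two", "3"], return_exceptions=True)
# results[0] and results[2] are parsed ints; results[1] is a ValueError instance
```

The caller then decides what to do with each exception, for example turning it into a message for the agent.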

Finally, the framework provides a with_fallbacks method that allows you to specify a sequence of fallback runnables to try if the original runnable fails. These features provide a robust way to handle errors and warnings during the execution of a tool in the LangChain framework. You can find these methods in the libs/core/langchain_core/runnables/base.py file.

deepakagrawal commented 7 months ago

@dosu-bot you are going in loop. This was your initial solution. But that did not work.

dosubot[bot] commented 7 months ago

Based on the information provided, it seems like you're trying to add a warning message to the validation process without breaking the LangChain flow. However, the LangChain framework uses Pydantic for input validation, which raises a ValidationError when the input is not valid. This behavior is inherent to Pydantic and cannot be changed without modifying the Pydantic library itself.

However, there are a few workarounds you could consider:

  1. Custom Error Handling: You could catch the ValidationError raised by Pydantic and handle it in your own code. This would allow you to return a warning message instead of breaking the flow. Here's a simple example:
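from pydantic import BaseModel, ValidationError

class MyModel(BaseModel):
    input1: int
    input2: int

def validate_inputs(input1, input2):
    try:
        # On success, return the validated (possibly coerced) values
        return MyModel(input1=input1, input2=input2).dict()
    except ValidationError as e:
        # On failure, return the error text as a warning instead of raising
        return {"warning": str(e)}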

In this example, the validate_inputs function will return a dictionary with a "warning" field if the inputs are not valid. You can then check for this "warning" field in your code and handle it accordingly.

  2. Use of Fallbacks: As mentioned above, the LangChain framework supports fallbacks. If the original runnable fails, the framework tries each fallback in order. You could potentially use this feature to handle invalid inputs, for example with a fallback that simply returns a warning message when the inputs are not valid.
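Both workarounds come down to turning a validation failure into a string the agent can read. Here is a minimal, stdlib-only sketch of that idea for the original requirement (exactly input1 and input2, both strings). EXPECTED_TYPES, InputValidationError, validate_tool_input, concat_inputs, warn_agent and run_with_fallback are all hypothetical names for illustration, not LangChain APIs; run_with_fallback only mimics the try-primary-then-fallback behaviour described above.

```python
from typing import Any, Callable, Dict

# Hypothetical schema for this sketch: exactly these keys, with these types
EXPECTED_TYPES: Dict[str, type] = {"input1": str, "input2": str}


class InputValidationError(ValueError):
    """Raised when tool input has missing or extra keys, or a value of the wrong type."""


def validate_tool_input(tool_input: Dict[str, Any]) -> Dict[str, Any]:
    """Check the exact key set first, then each value's type."""
    extra = sorted(set(tool_input) - set(EXPECTED_TYPES))
    missing = sorted(set(EXPECTED_TYPES) - set(tool_input))
    if extra or missing:
        raise InputValidationError(f"unexpected keys {extra}, missing keys {missing}")
    for key, expected in EXPECTED_TYPES.items():
        value = tool_input[key]
        if not isinstance(value, expected):
            raise InputValidationError(
                f"{key} must be {expected.__name__}, got {type(value).__name__}"
            )
    return tool_input


def concat_inputs(tool_input: Dict[str, Any]) -> str:
    """Primary path: validate, then do the tool's actual work."""
    data = validate_tool_input(tool_input)
    return data["input1"] + data["input2"]


def warn_agent(tool_input: Dict[str, Any], error: Exception) -> str:
    """Fallback path: turn the validation error into text the agent can read."""
    return f"The inputs passed were incorrect ({error}), please try again"


def run_with_fallback(
    primary: Callable[[Dict[str, Any]], str],
    fallback: Callable[[Dict[str, Any], Exception], str],
    tool_input: Dict[str, Any],
) -> str:
    """Try the primary path; on a validation error, return the fallback's text instead."""
    try:
        return primary(tool_input)
    except InputValidationError as e:
        return fallback(tool_input, e)


print(run_with_fallback(concat_inputs, warn_agent, {"input1": "a", "input2": "b"}))  # ab
print(run_with_fallback(concat_inputs, warn_agent, {"input1": "a", "input3": 1}))
# The inputs passed were incorrect (unexpected keys ['input3'], missing keys ['input2']), please try again
```

Checking the key set before the types means that an agent which invents an extra argument gets a message naming it, and only InputValidationError is caught, so unrelated bugs in the tool still surface.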

Remember, these are just workarounds and may not be the most elegant solutions. They also require you to handle the warning messages in your own code, which may not be ideal. Unfortunately, without the ability to modify the LangChain framework or the Pydantic library, there's no way to change the default behavior of the input validation process.

deepakagrawal commented 7 months ago

@dosu-bot where would I call this validate_inputs function?

dosubot[bot] commented 7 months ago

🤖

Based on the structure of the LangChain framework, the validate_inputs function you've created should ideally be called within the _parse_input method of the BaseTool class. This method is responsible for validating the input arguments of a tool.

Here's an example of how you might modify the _parse_input method so that it catches the ValidationError itself, following the same try/except pattern as validate_inputs:

import warnings
from typing import Any, Dict, Union

from langchain_core.pydantic_v1 import ValidationError


# Modified version of BaseTool._parse_input
def _parse_input(
    self,
    tool_input: Union[str, Dict],
) -> Union[str, Dict[str, Any]]:
    """Convert tool input to pydantic model."""
    input_args = self.args_schema
    try:
        if isinstance(tool_input, str):
            if input_args is not None:
                key_ = next(iter(input_args.__fields__.keys()))
                input_args.validate({key_: tool_input})
            return tool_input
        else:
            if input_args is not None:
                result = input_args.parse_obj(tool_input)
                return {k: v for k, v in result.dict().items() if k in tool_input}
    except ValidationError as e:
        # Emit a Python warning and fall through to return the unvalidated input
        warnings.warn(f"Input validation failed: {e}")
    return tool_input

In this modified version of _parse_input, if the validation fails, a warning is issued and the original input is returned. Please note that this modification will change the behavior of your application, as invalid inputs will no longer raise an error but will instead be passed through to the tool.
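The trade-off described above can be seen in a stdlib-only sketch. parse_input_or_warn and REQUIRED_KEYS are hypothetical stand-ins for _parse_input and the args schema. Note that warnings.warn goes to Python's warnings machinery (by default printed to stderr), so the warning is visible to the developer but not to the LLM agent.

```python
import warnings
from typing import Any, Dict

REQUIRED_KEYS = {"input1", "input2"}  # hypothetical stand-in for args_schema


def parse_input_or_warn(tool_input: Dict[str, Any]) -> Dict[str, Any]:
    missing = REQUIRED_KEYS - set(tool_input)
    if missing:
        # Like the modified _parse_input: warn, then pass the input through unchanged
        warnings.warn(f"Input validation failed: missing {sorted(missing)}")
    return tool_input


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    parsed = parse_input_or_warn({"input1": "a"})

print(parsed)  # {'input1': 'a'}  (still invalid, and the tool will receive it)
print(caught[0].message)  # Input validation failed: missing ['input2']
```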

You can find the _parse_input method in the BaseTool class in the libs/core/langchain_core/tools.py file.

Please note that this is just a suggestion based on the information provided. Depending on the specifics of your validate_inputs function and how you want your application to handle invalid inputs, you might need to adjust this approach.

ahmadaii commented 7 months ago

@deepakagrawal Were you able to handle this? I'm facing a similar validation problem :)

ahmadaii commented 7 months ago

@deepakagrawal Have you looked into the "Handle Tool Error" section of this doc?

hmasdev commented 7 months ago

@deepakagrawal @ahmadaii I also faced a similar validation problem, but I have solved it, so I'll share a simple, quick solution with you.

Python 3.9.13 (tags/v3.9.13:6de2ca5, May 17 2022, 16:36:42) [MSC v.1929 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import langchain
>>> langchain.__version__
'0.0.330'
>>> from functools import wraps
>>> from langchain.tools import StructuredTool
>>> from langchain import pydantic_v1 as pydantic
>>> # HERE IS A SAMPLE TOOL
>>> def joint(a: str, b: str) -> str:
...     '''joint two strings'''
...     return a + b
...
>>> joint_tool = StructuredTool.from_function(joint, name=joint.__name__, description=joint.__doc__, handle_tool_error=True)
>>> # PROBLEM REPRODUCTION
>>> joint_tool.run(dict(a='a', b='b'))
'ab'
>>> joint_tool.run(dict(a='a')) 
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\dev_libs\langchain\libs\core\langchain_core\tools.py", line 311, in run
    parsed_input = self._parse_input(tool_input)
  File "D:\dev_libs\langchain\libs\core\langchain_core\tools.py", line 246, in _parse_input
    result = input_args.parse_obj(tool_input)
  File "D:\dev_libs\langchain\venv\lib\site-packages\pydantic\v1\main.py", line 526, in parse_obj
    return cls(**obj)
  File "D:\dev_libs\langchain\venv\lib\site-packages\pydantic\v1\main.py", line 341, in __init__
    raise validation_error
pydantic.v1.error_wrappers.ValidationError: 1 validation error for jointSchemaSchema
b
  field required (type=value_error.missing)
>>> # SOLUTION IS AS FOLLOWS
>>> joint_tool.__config__.extra = pydantic.Extra.allow            
>>> def handle_validation_error_deco(func):
...     
...     @wraps(func)
...     def wrapped(*args, **kwargs):
...         try:
...             return func(*args, **kwargs)
...         except pydantic.ValidationError as e:
...             return str(e)
...     return wrapped
...

>>> joint_tool.run  = handle_validation_error_deco(joint_tool.run)
>>> joint_tool.run(dict(a='a'))
'1 validation error for jointSchemaSchema\nb\n  field required (type=value_error.missing)'
>>> type( joint_tool.run(dict(a='a'))) 
<class 'str'>

By the way, I think this problem is essentially caused by ValidationError not being handled in BaseTool.run and BaseTool.arun:

So I believe one of the best solutions is as follows

mnicstruwig commented 7 months ago

@hmasdev Unfortunately I don't believe the proposed solution works (I think because Pydantic is particular about modifying fields after declaration):

>>> tool.run = handle_validation_error(tool.run)
ValueError: "StructuredTool" object has no field "run"

The fix appears to be to call the built-in object.__setattr__ directly, which bypasses Pydantic's own __setattr__ and forces the override (this is the proposed answer from the link above):

So the workaround for others waiting for the PR to get merged would be to do:

object.__setattr__(joint_tool, "run", handle_validation_error_deco(joint_tool.run))

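Why object.__setattr__ works: Pydantic v1 models override __setattr__ and reject names that are not declared fields (the ValueError quoted above), but calling object.__setattr__ directly skips that override and stores the attribute on the instance, where it shadows the class's run method. The class below is a simplified, hypothetical stand-in for that behaviour (FieldGuardedModel and return_errors_as_text are illustrative names, not Pydantic or LangChain code).

```python
from functools import wraps
from typing import Any, Callable


class FieldGuardedModel:
    """Mimics a Pydantic v1 model: only declared fields may be assigned."""

    __fields__ = {"name"}

    def __init__(self, name: str) -> None:
        object.__setattr__(self, "name", name)

    def __setattr__(self, key: str, value: Any) -> None:
        if key not in self.__fields__:
            raise ValueError(f'"{type(self).__name__}" object has no field "{key}"')
        object.__setattr__(self, key, value)

    def run(self, text: str) -> str:
        if not text:
            raise ValueError("empty input")
        return f"{self.name}: {text}"


def return_errors_as_text(func: Callable[..., str]) -> Callable[..., str]:
    """Wrap func so a ValueError comes back as its message instead of propagating."""
    @wraps(func)
    def wrapped(*args: Any, **kwargs: Any) -> str:
        try:
            return func(*args, **kwargs)
        except ValueError as e:
            return str(e)
    return wrapped


tool = FieldGuardedModel("joint")
try:
    tool.run = return_errors_as_text(tool.run)  # rejected by the __setattr__ override
except ValueError as e:
    print(e)  # "FieldGuardedModel" object has no field "run"

object.__setattr__(tool, "run", return_errors_as_text(tool.run))  # bypasses the override
print(tool.run(""))  # empty input
```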
hmasdev commented 7 months ago

@mnicstruwig Thanks for alternative solution.

Just in case: I'm afraid you forgot to set the config before overwriting tool.run.

>>> tool.__config__.extra = pydantic.Extra.allow 
>>> tool.run = handle_validation_error(tool.run)

See https://github.com/pydantic/pydantic/blob/v1.10.13/pydantic/main.py#L356

But your solution is better :)

deepakagrawal commented 4 months ago

Sorry for the looong delay in responding. I ended up overriding the tool's _parse_input method:

import logging
from typing import Any, Dict, Union

from langchain.pydantic_v1 import ValidationError
from langchain.tools import BaseTool

logger = logging.getLogger(__name__)


class MyTool(BaseTool):
    # MyTool is a placeholder for your own tool; name, description, args_schema and _run omitted
    def _parse_input(
        self,
        # pyre-ignore
        tool_input: Union[str, Dict],
    ) -> Union[str, Dict[str, Any]]:
        """Convert tool input to pydantic model."""
        input_args = self.args_schema
        if isinstance(tool_input, str):
            if input_args is not None:
                key_ = next(iter(input_args.__fields__.keys()))
                input_args.validate({key_: tool_input})
            return tool_input
        else:
            if input_args is not None:
                try:
                    result = input_args.parse_obj(tool_input)
                    return {k: v for k, v in result.dict().items() if k in tool_input}
                except ValidationError as e:
                    # Log the error and return a message string instead of raising
                    logger.error(f"Failed to validate input: {e}")
                    return "Failed to validate input. I cannot use this tool. You should try another tool."
        return "Failed to validate input. I cannot use this tool. You should try another tool."

Bhargav-1997 commented 3 months ago

Hello, I have also encountered the same issue: I want to perform input validation and, based on the result, either continue with tool execution or ask the user for correct inputs.

I'd like your suggestion: would it be a good approach to have a separate LLM agent, with a set of instructions for input validation using chain-of-thought, and decide on the next step based on its verdict?

SinaArdehali commented 2 months ago

There seems to be no proper way of doing input validation in custom tools with LangChain 😞

mnicstruwig commented 2 months ago

@SinaArdehali, @Bhargav-1997, @deepakagrawal, I've found a lot of success using magentic instead for this use case, since it allows you to "pause" at the function-calling step to do whatever you'd like with your input arguments.