run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Bug]: Curly braces in the prompt template (e.g. a literal JSON snippet) cause a formatting error #14310

Open danerlt opened 2 weeks ago

danerlt commented 2 weeks ago

Bug Description

I wrote a test file named test_have_json.py with the following content:

from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.prompts import (
    PromptTemplate,
)

def test_template_hava_json() -> None:
    """Test partial format."""
    prompt_txt = 'hello {text} {foo} \noutput format:\n```json\n{"name": "llamaindex"}\n```'
    except_prompt = 'hello world bar \noutput format:\n```json\n{"name": "llamaindex"}\n```'
    print(prompt_txt)

    prompt_template = PromptTemplate(prompt_txt)

    prompt_fmt = prompt_template.partial_format(foo="bar")
    prompt = prompt_fmt.format(text="world")
    print(prompt)
    assert isinstance(prompt_fmt, PromptTemplate)

    assert prompt == except_prompt

    assert prompt_fmt.format_messages(text="world") == [
        ChatMessage(content=except_prompt, role=MessageRole.USER)
    ]

When I ran the test with pytest, it failed with KeyError: '"name"'.
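The failure can be reproduced with plain str.format, independent of llama_index: str.format treats every {...} in the template as a replacement field, so the literal JSON braces are parsed as a field named '"name"' (minimal stdlib-only sketch, JSON fence omitted for brevity):

```python
# str.format parses the literal JSON braces as a replacement field,
# so the lookup for a key named '"name"' raises KeyError.
template = 'hello {text}\noutput format:\n{"name": "llamaindex"}'

try:
    template.format(text="world")
except KeyError as exc:
    print(f"KeyError: {exc}")  # KeyError: '"name"'
```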

Version

0.10.48.post1

Steps to Reproduce

Run the test file above with pytest.

Relevant Logs/Tracebacks

pytest traceback:

pytest test_have_json.py 
============================================================== test session starts ===============================================================
platform linux -- Python 3.10.14, pytest-8.2.2, pluggy-1.5.0
rootdir: /workspace/ai/tests
plugins: anyio-4.2.0
collected 1 item                                                                                                                                 

test_have_json.py F                                                                                                                        [100%]

==================================================================== FAILURES ====================================================================
____________________________________________________________ test_template_hava_json _____________________________________________________________

    def test_template_hava_json() -> None:
        """Test partial format."""
        prompt_txt = 'hello {text} {foo} \noutput format:\n```json\n{"name": "llamaindex"}\n```'
        except_prompt = 'hello world bar \noutput format:\n```json\n{"name": "llamaindex"}\n```'
        print(prompt_txt)

        prompt_template = PromptTemplate(prompt_txt)

        prompt_fmt = prompt_template.partial_format(foo="bar")
>       prompt = prompt_fmt.format(text="world")

test_have_json.py:25: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = PromptTemplate(metadata={'prompt_type': <PromptType.CUSTOM: 'custom'>}, template_vars=['text', 'foo', '"name"'], kwarg...ngs=None, function_mappings=None, template='hello {text} {foo} \noutput format:\n```json\n{"name": "llamaindex"}\n```')
completion_to_prompt = None, kwargs = {'text': 'world'}, all_kwargs = {'foo': 'bar', 'text': 'world'}
mapped_all_kwargs = {'foo': 'bar', 'text': 'world'}

    def format(
        self,
        llm: Optional[BaseLLM] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        **kwargs: Any,
    ) -> str:
        """Format the prompt into a string."""
        del llm  # unused
        all_kwargs = {
            **self.kwargs,
            **kwargs,
        }

        mapped_all_kwargs = self._map_all_vars(all_kwargs)
>       prompt = self.template.format(**mapped_all_kwargs)
E       KeyError: '"name"'

/root/miniconda3/envs/ai/lib/python3.10/site-packages/llama_index/core/prompts/base.py:196: KeyError
-------------------------------------------------------------- Captured stdout call --------------------------------------------------------------
hello {text} {foo} 
output format:
```json
{"name": "llamaindex"}
```

============================================================ short test summary info =============================================================
FAILED test_have_json.py::test_template_hava_json - KeyError: '"name"'
=============================================================== 1 failed in 3.08s ================================================================

dosubot[bot] commented 2 weeks ago

The KeyError: '"name"' occurs because the curly braces in the prompt template are parsed as placeholders. To fix it, make sure every placeholder in the template is properly formatted and given a corresponding value when the test runs.

Here is an example of how you can define and test a template with curly braces:

from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.prompts import PromptTemplate
from llama_index.core.prompts.utils import get_template_vars

def test_template_hava_json() -> None:
    """Test template with curly braces."""
    prompt_txt = 'hello {text} {foo} \noutput format:\n```json\n{"name": "llamaindex"}\n```'
    except_prompt = 'hello world bar \noutput format:\n```json\n{"name": "llamaindex"}\n```'
    print(prompt_txt)

    prompt_template = PromptTemplate(prompt_txt)

    # Ensure all placeholders are provided
    template_vars = get_template_vars(prompt_txt)
    assert template_vars == ["text", "foo"]

    prompt_fmt = prompt_template.partial_format(foo="bar")
    assert isinstance(prompt_fmt, PromptTemplate)

    prompt = prompt_fmt.format(text="world")
    print(prompt)
    assert prompt == except_prompt

    assert prompt_fmt.format_messages(text="world") == [
        ChatMessage(content=except_prompt, role=MessageRole.USER)
    ]

In this example, the get_template_vars function extracts the template variables so that every placeholder is identified. The partial_format method then supplies a value for foo, and the format method supplies a value for text. This ensures that all placeholders are properly filled, preventing the KeyError [1][2].

To continue talking to Dosu, mention @dosu.

danerlt commented 2 weeks ago

We could follow the approach of LangChain's PromptTemplate and add support for Jinja2 formatting:


class PromptTemplate(StringPromptTemplate):
    """Prompt template for a language model.

    A prompt template consists of a string template. It accepts a set of parameters
    from the user that can be used to generate a prompt for a language model.

    The template can be formatted using either f-strings (default) or jinja2 syntax.

    *Security warning*: Prefer using `template_format="f-string"` instead of
        `template_format="jinja2"`, or make sure to NEVER accept jinja2 templates
        from untrusted sources as they may lead to arbitrary Python code execution.

        As of LangChain 0.0.329, Jinja2 templates will be rendered using
        Jinja2's SandboxedEnvironment by default. This sand-boxing should
        be treated as a best-effort approach rather than a guarantee of security,
        as it is an opt-out rather than opt-in approach.

        Despite the sand-boxing, we recommend to never use jinja2 templates
        from untrusted sources.

    Example:

        .. code-block:: python

            from langchain_core.prompts import PromptTemplate

            # Instantiation using from_template (recommended)
            prompt = PromptTemplate.from_template("Say {foo}")
            prompt.format(foo="bar")

            # Instantiation using initializer
            prompt = PromptTemplate(template="Say {foo}")
    """

    @classmethod
    def from_template(
        cls,
        template: str,
        *,
        template_format: str = "f-string",
        partial_variables: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> PromptTemplate:
        """Load a prompt template from a template.

        *Security warning*: Prefer using `template_format="f-string"` instead of
            `template_format="jinja2"`, or make sure to NEVER accept jinja2 templates
            from untrusted sources as they may lead to arbitrary Python code execution.

            As of LangChain 0.0.329, Jinja2 templates will be rendered using
            Jinja2's SandboxedEnvironment by default. This sand-boxing should
            be treated as a best-effort approach rather than a guarantee of security,
            as it is an opt-out rather than opt-in approach.

            Despite the sand-boxing, we recommend to never use jinja2 templates
            from untrusted sources.

        Args:
            template: The template to load.
            template_format: The format of the template. Use `jinja2` for jinja2,
                             and `f-string` or None for f-strings.
            partial_variables: A dictionary of variables that can be used to partially
                               fill in the template. For example, if the template is
                              `"{variable1} {variable2}"`, and `partial_variables` is
                              `{"variable1": "foo"}`, then the final prompt will be
                              `"foo {variable2}"`.

        Returns:
            The prompt template loaded from the template.
        """

        input_variables = get_template_variables(template, template_format)
        _partial_variables = partial_variables or {}

        if _partial_variables:
            input_variables = [
                var for var in input_variables if var not in _partial_variables
            ]

        return cls(
            input_variables=input_variables,
            template=template,
            template_format=template_format,  # type: ignore[arg-type]
            partial_variables=_partial_variables,
            **kwargs,
        )
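A minimal sketch of what a jinja2-backed template format could look like, modeled on the LangChain code above (the format_jinja2 helper and the {{ var }} prompt below are illustrative, not existing llama_index API; assumes jinja2 is installed):

```python
from jinja2.sandbox import SandboxedEnvironment

def format_jinja2(template: str, **kwargs) -> str:
    """Render a jinja2 template in a sandbox; literal JSON braces pass through."""
    env = SandboxedEnvironment()
    return env.from_string(template).render(**kwargs)

# Variables use {{ var }}; the single braces in the JSON need no escaping.
prompt_txt = 'hello {{ text }} {{ foo }}\noutput format:\n{"name": "llamaindex"}'
print(format_jinja2(prompt_txt, text="world", foo="bar"))
# hello world bar
# output format:
# {"name": "llamaindex"}
```

As the LangChain docstring warns, the SandboxedEnvironment is best-effort only, so jinja2 templates should still never come from untrusted sources.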
tibor-reiss commented 4 days ago

See comment in #14311:

The solution is to double up the braces:

prompt_txt = 'hello {text} {foo} \noutput format:\n```json\n{{"name": "llamaindex"}}\n```'

Docs: https://docs.python.org/3/library/string.html#formatstrings
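The escaping can be checked with plain str.format (per the format-string docs linked above, {{ and }} are emitted as literal braces; JSON fence omitted for brevity):

```python
# Doubled braces come out as literal single braces after formatting.
prompt_txt = 'hello {text} {foo}\noutput format:\n{{"name": "llamaindex"}}'
print(prompt_txt.format(text="world", foo="bar"))
# hello world bar
# output format:
# {"name": "llamaindex"}
```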

danerlt commented 4 days ago

@tibor-reiss Thank you for your response. I understand that this behavior comes from the standard library's str.format, but I would prefer not to alter the JSON by doubling the braces. Instead, I would like to address it with jinja2, similar to how the LangChain framework handles this.
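For comparison, the standard library's string.Template is another way to leave the JSON untouched: it uses $-placeholders and ignores curly braces entirely (a stdlib workaround sketch, not how llama_index's PromptTemplate works):

```python
from string import Template

prompt = Template('hello $text $foo\noutput format:\n{"name": "llamaindex"}')

# safe_substitute behaves like partial_format: unknown placeholders survive.
partial = Template(prompt.safe_substitute(foo="bar"))
print(partial.substitute(text="world"))
# hello world bar
# output format:
# {"name": "llamaindex"}
```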