Sinaptik-AI / pandas-ai

Chat with your database (SQL, CSV, pandas, polars, mongodb, noSQL, etc). PandasAI makes data analysis conversational using LLMs (GPT 3.5 / 4, Anthropic, VertexAI) and RAG.
https://pandas-ai.com

`data_viz_library` cannot be unset #1337

Open Yelinz opened 2 months ago

Yelinz commented 2 months ago

System Info

pandasai 2.2.14 python 3.12

🐛 Describe the bug

https://github.com/Sinaptik-AI/pandas-ai/blob/05431072676d44d409c6c95620c6f561370ec3ef/pandasai/pipelines/chat/prompt_generation.py#L49

The prompt always includes an instruction on how to plot. `data_viz_library` cannot be set to None or another falsy value to trigger the exclusion that already exists in the template.

It should be possible to pass None or similar to remove the sentence `If you are asked to plot a chart, use "{{viz_lib}}" for charts, save as png.` from the prompt.
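
To make the request concrete, here is a minimal sketch of the difference between the reported and the requested behavior. The helper names and the `"matplotlib"` default are assumptions for illustration only; they are not taken from the pandasai source.

```python
from typing import Optional

# Assumed default library, used here only to illustrate a fallback.
DEFAULT_VIZ_LIB = "matplotlib"


def current_viz_lib(configured: Optional[str]) -> str:
    """Reported behavior: a falsy setting is replaced by a default."""
    return configured if configured else DEFAULT_VIZ_LIB


def requested_viz_lib(configured: Optional[str]) -> Optional[str]:
    """Requested behavior: a falsy setting is passed through as None."""
    return configured or None


def plot_instruction(viz_lib: Optional[str]) -> str:
    """Mimic the template: emit the sentence only for a truthy viz_lib."""
    if not viz_lib:
        return ""
    return (
        f'If you are asked to plot a chart, use "{viz_lib}" '
        "for charts, save as png."
    )


# With a fallback, the sentence can never be removed.
print(repr(plot_instruction(current_viz_lib(None))))
# Without a fallback, None removes it.
print(repr(plot_instruction(requested_viz_lib(None))))
```

Under these assumptions, only the second call yields an empty string, which is the outcome asked for here.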

dosubot[bot] commented 2 months ago

Hey @Yelinz! I'm here to help you with any bugs, questions, or contributions you have for the pandas-ai repository. Let's squash those bugs together! 🤖

To address the issue where the data_viz_library parameter cannot be set to None or a falsy value to exclude the plotting instruction, you need to modify the prompt_generation.py file and the relevant prompt classes. Here are the steps to achieve this:

  1. Modify prompt_generation.py: Update the get_chat_prompt method to handle a falsy data_viz_library value and exclude the plotting instruction accordingly.
```python
from typing import Any, Union

from pandasai.pipelines.logic_unit_output import LogicUnitOutput

from ...helpers.logger import Logger
from ...prompts.base import BasePrompt
from ...prompts.generate_python_code import GeneratePythonCodePrompt
from ...prompts.generate_python_code_with_sql import GeneratePythonCodeWithSQLPrompt
from ..base_logic_unit import BaseLogicUnit
from ..pipeline_context import PipelineContext


class PromptGeneration(BaseLogicUnit):
    """
    Code Prompt Generation Stage
    """

    def execute(self, input: Any, **kwargs) -> Any:
        """
        This method will return output according to
        Implementation.

        :param input: Your input data.
        :param kwargs: A dictionary of keyword arguments.
            - 'logger' (any): The logger for logging.
            - 'config' (Config): Global configurations for the test
            - 'context' (any): The execution context.

        :return: The result of the execution.
        """
        self.context: PipelineContext = kwargs.get("context")
        self.logger: Logger = kwargs.get("logger")

        prompt = self.get_chat_prompt(self.context)
        self.logger.log(f"Using prompt: {prompt}")

        return LogicUnitOutput(
            prompt,
            True,
            "Prompt Generated Successfully",
            {"content_type": "prompt", "value": prompt.to_string()},
        )

    def get_chat_prompt(self, context: PipelineContext) -> Union[str, BasePrompt]:
        # Normalize any falsy data_viz_library ("" or None) to None, with no
        # default fallback, so the template can skip the plotting instruction.
        viz_lib = context.config.data_viz_library or None

        output_type = context.get("output_type")

        # Both prompt classes take the same arguments; pick one by config.
        prompt_class = (
            GeneratePythonCodeWithSQLPrompt
            if context.config.direct_sql
            else GeneratePythonCodePrompt
        )
        return prompt_class(
            context=context,
            last_code_generated=context.get("last_code_generated"),
            viz_lib=viz_lib,
            output_type=output_type,
        )
```
  2. Modify the GeneratePythonCodePrompt class: Update the to_json method to conditionally include the viz_lib parameter.
```python
def to_json(self):
    context = self.props["context"]
    viz_lib = self.props["viz_lib"]
    output_type = self.props["output_type"]
    memory = context.memory
    conversations = memory.to_json()

    system_prompt = memory.get_system_prompt()

    # prepare datasets
    datasets = [dataset.to_json() for dataset in context.dfs]

    config = {
        "direct_sql": context.config.direct_sql,
        "output_type": output_type,
    }

    if viz_lib:  # Only include viz_lib if it is not None or empty
        config["viz_lib"] = viz_lib

    return {
        "datasets": datasets,
        "conversation": conversations,
        "system_prompt": system_prompt,
        "prompt": self.to_string(),
        "config": config,
    }
```
  3. Modify the template: Ensure the template conditionally includes the plotting instruction based on the viz_lib parameter.
```jinja
{% if viz_lib %}
If you are asked to plot a chart, use "{{viz_lib}}" for charts, save as png.
{% endif %}
```
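
As a quick check of the conditional above, the following sketch renders the same fragment with Jinja2 for a truthy and two falsy values. It assumes the `jinja2` package is installed; `render_plot_instruction` is a hypothetical helper written for this illustration, not part of pandasai.

```python
from jinja2 import Template

# Template fragment copied from the step above, joined onto one line.
PLOT_TEMPLATE = Template(
    "{% if viz_lib %}"
    'If you are asked to plot a chart, use "{{viz_lib}}" for charts, save as png.'
    "{% endif %}"
)


def render_plot_instruction(viz_lib):
    """Render the fragment; a falsy viz_lib should produce an empty string."""
    return PLOT_TEMPLATE.render(viz_lib=viz_lib).strip()


with_lib = render_plot_instruction("matplotlib")
without_lib = render_plot_instruction(None)
empty_lib = render_plot_instruction("")

print(repr(with_lib))
print(repr(without_lib))
print(repr(empty_lib))
```

Jinja2 treats both `None` and the empty string as false in an `{% if %}` test, so either value is enough to drop the sentence once no default is substituted upstream.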

With these changes, setting the data_viz_library parameter to None or another falsy value omits the plotting instruction from the prompt [1][2].
