Canner / WrenAI

🚀 An open-source SQL AI (Text-to-SQL) Agent that empowers data and product teams to chat with their data. 🤘
https://getwren.ai/oss
GNU Affero General Public License v3.0

feature(wren-ai-service): refactor indexing pipeline with Hamilton #345

Closed paopa closed 5 months ago

paopa commented 5 months ago

This is a follow-up to PR #316, and it aims to refactor the indexing pipeline.

cyyeh commented 5 months ago

@paopa I experimented with Hamilton a bit and found that we can pass all of the pipeline's inputs through the inputs argument of self._pipe.execute, so the code becomes the following. I think the code is cleaner now, and we no longer need global variables.

For the details and the related PR, please check out https://github.com/Canner/WrenAI/pull/363.

Parts of the modified code in query_understanding_pipeline.py:

import logging
import sys
from typing import Any, List

import orjson
from hamilton import base
from hamilton.experimental.h_async import AsyncDriver
from haystack import component
from haystack.components.builders.prompt_builder import PromptBuilder

# Project-internal helpers (BasicPipeline, LLMProvider, timer) are
# imported from elsewhere in wren-ai-service and omitted in this excerpt.

logger = logging.getLogger(__name__)

_prompt = """
### TASK ###
Based on the user's input below, classify whether the query is not random words.
Provide your classification as 'Yes' or 'No'. Yes if you think the query is not random words, and No if you think the query is random words.

### FINAL ANSWER FORMAT ###
The final answer must be the JSON format like following:

{
    "result": "yes" or "no"
}

### INPUT ###
{{ query }}

Let's think step by step.
"""

@component
class QueryUnderstandingPostProcessor:
    @component.output_types(
        is_valid_query=bool,
    )
    def run(self, replies: List[str]):
        try:
            result = orjson.loads(replies[0])["result"]

            if result == "yes":
                return {
                    "is_valid_query": True,
                }

            return {
                "is_valid_query": False,
            }
        except Exception as e:
            logger.error(f"Error in QueryUnderstandingPostProcessor: {e}")

            return {
                "is_valid_query": True,
            }

def prompt(query: str, prompt_builder: PromptBuilder) -> dict:
    logger.debug(f"query: {query}")
    return prompt_builder.run(query=query)

async def generate(prompt: dict, generator: Any) -> dict:
    logger.debug(f"prompt: {prompt}")
    return await generator.run(prompt=prompt.get("prompt"))

def post_process(generate: dict, post_processor: QueryUnderstandingPostProcessor) -> dict:
    logger.debug(f"generate: {generate}")
    return post_processor.run(generate.get("replies"))

class QueryUnderstanding(BasicPipeline):
    def __init__(
        self,
        llm_provider: LLMProvider,
    ):
        self.generator = llm_provider.get_generator()
        self.prompt_builder = PromptBuilder(template=_prompt)
        self.post_processor = QueryUnderstandingPostProcessor()

        super().__init__(
            AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
        )

    @timer
    async def run(
        self,
        query: str,
    ):
        logger.info("Ask QueryUnderstanding pipeline is running...")
        return await self._pipe.execute(["post_process"], inputs={
            "query": query,
            "generator": self.generator,
            "prompt_builder": self.prompt_builder,
            "post_processor": self.post_processor,
        })
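To make the idea above concrete, here is a minimal, illustrative sketch (not Hamilton's real implementation, and not part of the PR) of how a Hamilton-style driver wires the pipeline: each function is a node, its parameter names declare its dependencies, and any name that no function produces must be supplied via the inputs dict. This is why the components (generator, prompt_builder, post_processor) can be passed as inputs instead of living as module globals. All function and variable names here are hypothetical stand-ins.

```python
import inspect

def _resolve(name, funcs, inputs, cache):
    """Compute node `name`, recursing into its parameter names."""
    if name not in cache:
        if name in funcs:
            fn = funcs[name]
            # A parameter is either another node or a key in `inputs`.
            kwargs = {p: _resolve(p, funcs, inputs, cache)
                      for p in inspect.signature(fn).parameters}
            cache[name] = fn(**kwargs)
        else:
            cache[name] = inputs[name]  # leaf: supplied via `inputs`
    return cache[name]

def execute(final_vars, funcs, inputs):
    """Rough analogue of self._pipe.execute(final_vars, inputs=...)."""
    cache = {}
    return {v: _resolve(v, funcs, inputs, cache) for v in final_vars}

# The three pipeline stages, mirroring the PR's function names.
def prompt(query, prompt_builder):
    return prompt_builder(query)

def generate(prompt, generator):
    return generator(prompt)

def post_process(generate, post_processor):
    return post_processor(generate)

funcs = {"prompt": prompt, "generate": generate, "post_process": post_process}
result = execute(["post_process"], funcs, inputs={
    "query": "show me sales by region",  # data input
    "prompt_builder": str.upper,         # stand-in "components" passed
    "generator": lambda p: p + "!",      # as inputs, not globals
    "post_processor": len,
})
print(result)  # {'post_process': 24}
```

Because the components arrive through inputs, swapping in a different generator or post-processor requires no changes to module-level state, which is the cleanliness win described above.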
paopa commented 5 months ago

> @paopa I experimented with Hamilton a bit and found that we can pass all of the pipeline's inputs through the inputs argument of self._pipe.execute, so the code becomes the following. I think the code is cleaner now, and we no longer need global variables.
>
> For the details and the related PR, please check out #363.
>
> Parts of the modified code in query_understanding_pipeline.py:
>
> ...

Nice idea! I had considered using this approach when refactoring the first pipeline, but at the time I preferred keeping the inputs limited to data only. Thinking about it again, though, I now believe that passing the components as inputs is a good idea.

[Image: pipeline DAG visualization]

paopa commented 5 months ago

We merged the changes into #363, so we can close this PR.