paopa closed this 5 months ago.
@paopa I experimented with Hamilton a little and found that we can pass all inputs of the pipeline, including the components, through the `inputs` argument of `self._pipe.execute`.
With that change, the code becomes the following. I think it is cleaner now, and we no longer need global variables.
For the details and the related PR, please check out: https://github.com/Canner/WrenAI/pull/363
Part of the modified code in `query_understanding_pipeline.py`:
```python
import logging
import sys
from typing import Any, List

import orjson
from hamilton import base
from hamilton.experimental.h_async import AsyncDriver
from haystack import component
from haystack.components.builders.prompt_builder import PromptBuilder

# BasicPipeline, LLMProvider and timer are helpers from the WrenAI codebase;
# their imports are omitted from this excerpt.

logger = logging.getLogger(__name__)

_prompt = """
### TASK ###
Based on the user's input below, classify whether the query is not random words.
Provide your classification as 'Yes' or 'No'. Yes if you think the query is not random words, and No if you think the query is random words.
### FINAL ANSWER FORMAT ###
The final answer must be the JSON format like following:
{
"result": "yes" or "no"
}
### INPUT ###
{{ query }}
Let's think step by step.
"""


@component
class QueryUnderstandingPostProcessor:
    @component.output_types(
        is_valid_query=bool,
    )
    def run(self, replies: List[str]):
        try:
            result = orjson.loads(replies[0])["result"]
            # The prompt mentions both "Yes" and "yes", so compare case-insensitively.
            return {
                "is_valid_query": str(result).lower() == "yes",
            }
        except Exception as e:
            logger.error(f"Error in QueryUnderstandingPostProcessor: {e}")
            # Fail open: treat the query as valid if the reply cannot be parsed.
            return {
                "is_valid_query": True,
            }


# Hamilton builds the DAG from these module-level functions: each parameter
# is resolved by name from another function's output or from `inputs`.
def prompt(query: str, prompt_builder: PromptBuilder) -> dict:
    logger.debug(f"query: {query}")
    return prompt_builder.run(query=query)


async def generate(prompt: dict, generator: Any) -> dict:
    logger.debug(f"prompt: {prompt}")
    return await generator.run(prompt=prompt.get("prompt"))


def post_process(generate: dict, post_processor: QueryUnderstandingPostProcessor) -> dict:
    logger.debug(f"generate: {generate}")
    return post_processor.run(generate.get("replies"))


class QueryUnderstanding(BasicPipeline):
    def __init__(
        self,
        llm_provider: LLMProvider,
    ):
        self.generator = llm_provider.get_generator()
        self.prompt_builder = PromptBuilder(template=_prompt)
        self.post_processor = QueryUnderstandingPostProcessor()

        super().__init__(
            AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
        )

    @timer
    async def run(
        self,
        query: str,
    ):
        logger.info("Ask QueryUnderstanding pipeline is running...")
        # Components are passed alongside the data, so no module-level globals are needed.
        return await self._pipe.execute(
            ["post_process"],
            inputs={
                "query": query,
                "generator": self.generator,
                "prompt_builder": self.prompt_builder,
                "post_processor": self.post_processor,
            },
        )
```
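To make the mechanism concrete, here is a toy, pure-Python stand-in for what the driver does in this pipeline. It is not Hamilton's implementation or API: the `execute` helper and the `Fake*` classes are hypothetical stubs. Each function parameter is matched by name either to another function's output or to a key in `inputs`, so components such as the generator can be supplied exactly like plain data.

```python
import asyncio
import inspect
import json
from typing import Any, Callable, Dict, List


async def execute(funcs: List[Callable], final_vars: List[str], inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical toy driver: resolves each parameter by name, first from
    `inputs` (or already-computed results), then from a function of that name."""
    nodes = {f.__name__: f for f in funcs}
    cache: Dict[str, Any] = dict(inputs)

    async def resolve(name: str) -> Any:
        if name in cache:
            return cache[name]
        if name not in nodes:
            raise KeyError(f"missing input: {name}")
        fn = nodes[name]
        kwargs = {p: await resolve(p) for p in inspect.signature(fn).parameters}
        result = fn(**kwargs)
        if inspect.isawaitable(result):  # async nodes such as `generate`
            result = await result
        cache[name] = result
        return result

    return {name: await resolve(name) for name in final_vars}


# Hypothetical stubs standing in for PromptBuilder, the LLM generator and the post-processor.
class FakePromptBuilder:
    def run(self, query: str) -> dict:
        return {"prompt": f"Classify: {query}"}


class FakeGenerator:
    async def run(self, prompt: str) -> dict:
        return {"replies": ['{"result": "yes"}']}


class FakePostProcessor:
    def run(self, replies: List[str]) -> dict:
        return {"is_valid_query": json.loads(replies[0])["result"] == "yes"}


def prompt(query: str, prompt_builder: FakePromptBuilder) -> dict:
    return prompt_builder.run(query=query)


async def generate(prompt: dict, generator: FakeGenerator) -> dict:
    return await generator.run(prompt=prompt["prompt"])


def post_process(generate: dict, post_processor: FakePostProcessor) -> dict:
    return post_processor.run(generate["replies"])


result = asyncio.run(
    execute(
        [prompt, generate, post_process],
        ["post_process"],
        inputs={
            "query": "how many orders were placed last month",
            "generator": FakeGenerator(),
            "prompt_builder": FakePromptBuilder(),
            "post_processor": FakePostProcessor(),
        },
    )
)
print(result)  # {'post_process': {'is_valid_query': True}}
```

Because the components arrive through `inputs`, swapping the real generator for a stub in a test only means passing a different object; there are no module-level globals to patch.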
Nice idea! I had considered this approach when refactoring the first pipeline, but at the time I preferred to limit the inputs to data only. Thinking about it again, I now believe that passing components as inputs is a good idea.
We merged the changes into #363, so we can close this PR.
This PR is a follow-up to #316 and aims to refactor the indexing pipeline.