deepset-ai / haystack

:mag: AI orchestration framework to build customizable, production-ready LLM applications. Connect components (models, vector DBs, file converters) to pipelines or agents that can interact with your data. With advanced retrieval methods, it's best suited for building RAG, question answering, semantic search or conversational agent chatbots.
https://haystack.deepset.ai
Apache License 2.0

Improve `TransformersQueryClassifier` #2587

Closed: ZanSara closed this issue 2 years ago

ZanSara commented 2 years ago

Problem: Currently, TransformersQueryClassifier is built very closely around the question/keywords/statement classifier model used in the tutorials ("shahrukhx01/bert-mini-finetune-question-detection"). In practice, it can only handle models that output binary labels, one of which must be called LABEL_1.
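To make the limitation concrete, here is a minimal, framework-free sketch of a binary routing rule like the one described above. It assumes that LABEL_1 goes to output_1 and any other label falls through to output_2; the function names are hypothetical and this is not the actual Haystack source, only an illustration of why a multi-class model cannot be expressed this way.

```python
from typing import List


def binary_edge(label: str) -> str:
    """Hypothetical binary rule (assumption): LABEL_1 -> output_1, anything else -> output_2."""
    return "output_1" if label == "LABEL_1" else "output_2"


def route_all(labels: List[str]) -> List[str]:
    """Apply the binary rule to every predicted label."""
    return [binary_edge(label) for label in labels]


# The tutorial-style binary model maps cleanly onto two edges...
print(route_all(["LABEL_1", "LABEL_0"]))
# ...but a three-class model collapses: every non-LABEL_1 label lands on the same edge.
print(route_all(["happy", "unhappy", "neutral"]))
```

With this rule, "happy", "unhappy" and "neutral" are indistinguishable downstream, which is the core of the problem.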

I believe these limitations make it unsuitable for any model other than the one used in the example.

Solution: HuggingFace now hosts a wide array of zero-shot text classification models that could be applied nicely to query classification, for example for sentiment/emotion analysis or for topic classification. With limited changes, TransformersQueryClassifier can be improved to use these models effectively.
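A HuggingFace zero-shot-classification pipeline returns, for each input, a dict whose "labels" are sorted by descending score. The sketch below assumes that output shape and shows how a user-supplied label list could define one output edge per label (labels[0] goes to output_1, and so on). The prediction dict is hand-written for illustration, not real model output, and edge_for_prediction is a hypothetical helper.

```python
from typing import Any, Dict, List


def edge_for_prediction(prediction: Dict[str, Any], labels: List[str]) -> str:
    """Map a zero-shot prediction to an output edge name.

    Assumes the zero-shot output shape {"sequence": str, "labels": [...], "scores": [...]},
    with labels sorted by score (highest first).
    """
    top_label = prediction["labels"][0]
    if top_label not in labels:
        raise ValueError(f"Unexpected label {top_label!r}; expected one of {labels}")
    # Edges are 1-indexed and follow the order of the user-supplied labels
    return f"output_{labels.index(top_label) + 1}"


labels = ["happy", "unhappy", "neutral"]

# Hand-written example in the zero-shot output format (scores are illustrative, not real)
prediction = {
    "sequence": "Can you give me the right answer for once??",
    "labels": ["unhappy", "neutral", "happy"],
    "scores": [0.7, 0.2, 0.1],
}
print(edge_for_prediction(prediction, labels))  # output_2
```

Because the edge is derived from the label's position, the number of outputs grows with the label list instead of being fixed at two.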

Note: It's currently possible to write custom nodes for this use case. Here is an example:

```python
from pprint import pprint
from typing import Dict, List, Optional, Tuple

from transformers import pipeline

from haystack import Answer, Pipeline
from haystack.nodes.base import BaseComponent
from haystack.nodes import TransformersQueryClassifier


class ZeroshotQueryClassifier(TransformersQueryClassifier):
    """
    Routes each query to the output edge of its most likely zero-shot label:
    labels[0] -> output_1, labels[1] -> output_2, and so on.
    """

    # Supports up to 10 labels, one output edge per label
    outgoing_edges: int = 10

    def __init__(
        self,
        model_name_or_path: str,
        labels: List[str],
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: Optional[int] = None,
    ):
        """
        :param model_name_or_path: accepts most zero-shot-classification models from HuggingFace (https://huggingface.co/models?pipeline_tag=zero-shot-classification)
        :param labels: the labels for zero-shot classification (for example a list of the emotions to classify by, or something like ["happy", "unhappy", "neutral"])
        """
        # The parent __init__ is called for its device setup (self.devices);
        # note that it may also load the parent's default model.
        super().__init__(use_gpu=use_gpu, batch_size=batch_size)
        if tokenizer is None:
            tokenizer = model_name_or_path
        self.model = pipeline(
            task="zero-shot-classification",
            model=model_name_or_path,
            tokenizer=tokenizer,
            device=0 if self.devices[0].type == "cuda" else -1,
            revision=model_version,
        )
        self.labels = labels

    def _get_edge_number(self, label: str) -> int:
        # Edges are 1-indexed: labels[0] -> output_1
        return self.labels.index(label) + 1

    def run(self, query: str) -> Tuple[Dict, str]:
        # A single string input yields one dict whose "labels" are sorted by score (highest first)
        prediction = self.model(query, candidate_labels=self.labels, truncation=True)
        label = prediction["labels"][0]
        return {"output": query}, f"output_{self._get_edge_number(label)}"

    def run_batch(self, queries: List[str]) -> Tuple[Dict, str]:
        predictions = self.model(queries, candidate_labels=self.labels, truncation=True)

        # One bucket per label, so every edge receives a (possibly empty) list of queries
        results = {f"output_{self._get_edge_number(label)}": {"queries": []} for label in self.labels}
        for query, prediction in zip(queries, predictions):
            label = prediction["labels"][0]
            results[f"output_{self._get_edge_number(label)}"]["queries"].append(query)

        return results, "split"

#
# Usage as a single node
#

query_classifier = ZeroshotQueryClassifier(
    model_name_or_path="typeform/distilbert-base-uncased-mnli",
    labels=["happy", "unhappy", "neutral"]
)

queries = [
    "What's the answer?",
    "Would you be so kind to tell me the answer?",
    "Can you give me the right answer for once??",
]

# Processing all queries in a single call
output = query_classifier.run_batch(queries=queries)
print()
pprint(output)
print()

# Processing one query at a time
for query in queries:
    output = query_classifier.run(query=query)
    pprint(output)
    print()

#
# Usage in a pipeline (with stub nodes)
#

class HappyAnswer(BaseComponent):
    outgoing_edges = 1

    def run(self, query: str):
        return {"answers": [Answer(answer="We're glad you like it!")]}, "output_1"

    def run_batch(self, queries: List[str]):
        return {"answers": [Answer(answer="We're glad you like it!")] * len(queries)}, "output_1"

class UnhappyAnswer(BaseComponent):
    outgoing_edges = 1

    def run(self, query: str):
        return {"answers": [Answer(answer="We're so sorry you're not happy :(")]}, "output_1"

    def run_batch(self, queries: List[str]):
        return {"answers": [Answer(answer="We're so sorry you're not happy :(")] * len(queries)}, "output_1"

class NeutralAnswer(BaseComponent):
    outgoing_edges = 1

    def run(self, query: str):
        return {"answers": [Answer(answer="Thanks for your feedback.")]}, "output_1"

    def run_batch(self, queries: List[str]):
        return {"answers": [Answer(answer="Thanks for your feedback.")] * len(queries)}, "output_1"

# Named query_pipeline so it doesn't shadow transformers.pipeline imported above
query_pipeline = Pipeline()
query_pipeline.add_node(component=query_classifier, name="classifier", inputs=["Query"])
query_pipeline.add_node(component=HappyAnswer(), name="happy", inputs=["classifier.output_1"])
query_pipeline.add_node(component=UnhappyAnswer(), name="unhappy", inputs=["classifier.output_2"])
query_pipeline.add_node(component=NeutralAnswer(), name="neutral", inputs=["classifier.output_3"])

# Drawing the pipeline requires pygraphviz to be installed
query_pipeline.draw("pipeline.png")

for query in queries:
    output = query_pipeline.run(query=query)
    pprint(output)
    print()
```
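The run_batch method in the example returns a single "split" stream and fills one bucket per label, so every edge gets a (possibly empty) list of queries. The framework-free sketch below reproduces that grouping step in isolation; group_by_edge is a hypothetical name, and the keyword-based stub_classifier is a hypothetical stand-in for the zero-shot model so the logic can be checked without downloading anything.

```python
from typing import Callable, Dict, List


def group_by_edge(
    queries: List[str], labels: List[str], classify: Callable[[str], str]
) -> Dict[str, Dict[str, List[str]]]:
    """Group queries into one bucket per label, mirroring the run_batch logic."""
    # Pre-create every bucket so edges with no matching query still receive an empty list
    results: Dict[str, Dict[str, List[str]]] = {
        f"output_{i + 1}": {"queries": []} for i in range(len(labels))
    }
    for query in queries:
        label = classify(query)
        results[f"output_{labels.index(label) + 1}"]["queries"].append(query)
    return results


def stub_classifier(query: str) -> str:
    """Hypothetical stand-in for the model: a crude keyword heuristic."""
    if "??" in query:
        return "unhappy"
    if "kind" in query:
        return "happy"
    return "neutral"


labels = ["happy", "unhappy", "neutral"]
queries = [
    "What's the answer?",
    "Would you be so kind to tell me the answer?",
    "Can you give me the right answer for once??",
]
print(group_by_edge(queries, labels, stub_classifier))
```

Pre-creating the buckets means a downstream node connected to an edge that no query reached still receives a well-formed, empty input rather than nothing at all.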
anakin87 commented 2 years ago

@ZanSara, just one question to better understand your view: should we make TransformersQueryClassifier more general, so that it can handle non-binary output labels?

ZanSara commented 2 years ago

Hey @anakin87! Yes, that's the aim. Right now the node is highly tailored to the specific model I mentioned above, so even other binary models would not work. First of all, I think any binary model should work as a QueryClassifier, but honestly I think it's worth taking this opportunity to really improve it. If it could handle a generic text classification model, it would be really cool :blush:

By the way: feel free to go for a heavy rewrite if you believe it's a good call. Just make sure that it's still compatible with the tutorial.
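One possible direction, sketched under assumptions: a generic HuggingFace text-classification pipeline returns a list of {"label": ..., "score": ...} dicts, so the node could take the highest-scoring label and map it to an edge through a user-supplied label list. If that list defaulted to ["LABEL_1", "LABEL_0"], the binary model from the tutorial would keep routing LABEL_1 to output_1 and LABEL_0 to output_2 (assuming that is the current edge assignment). GenericRouter and DEFAULT_LABELS are hypothetical names, not Haystack API, and the emotion labels below are illustrative.

```python
from typing import Any, Dict, List, Optional

# Hypothetical default chosen to mimic the current binary behaviour (LABEL_1 -> output_1)
DEFAULT_LABELS = ["LABEL_1", "LABEL_0"]


class GenericRouter:
    """Sketch: route a text-classification prediction to an output edge by label position."""

    def __init__(self, labels: Optional[List[str]] = None):
        self.labels = labels if labels is not None else DEFAULT_LABELS

    def edge(self, predictions: List[Dict[str, Any]]) -> str:
        # Pick the highest-scoring entry rather than assuming the list is sorted
        top = max(predictions, key=lambda p: p["score"])
        label = top["label"]
        if label not in self.labels:
            raise ValueError(f"Model returned label {label!r}, which is not in {self.labels}")
        return f"output_{self.labels.index(label) + 1}"


# Tutorial-style binary output (scores are illustrative)
print(GenericRouter().edge([{"label": "LABEL_1", "score": 0.9}]))
# Multi-class output with a custom label list
emotions = GenericRouter(labels=["joy", "anger", "sadness"])
print(emotions.edge([
    {"label": "joy", "score": 0.1},
    {"label": "sadness", "score": 0.8},
    {"label": "anger", "score": 0.1},
]))
```

Raising on an unknown label makes a mismatch between the model's label set and the configured edges fail loudly instead of silently routing to a wrong branch.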