huggingface / huggingface_hub

The official Python client for the Huggingface Hub.
https://huggingface.co/docs/huggingface_hub
Apache License 2.0

improve client.zero_shot_classification() #2340

Closed: MoritzLaurer closed this pull request 3 months ago

MoritzLaurer commented 3 months ago

This PR does 3 things:

  1. It adds support for the hypothesis_template argument of the zero-shot classification pipeline. This argument is important for customizing zero-shot classification because it controls the sentence each candidate label is inserted into.
  2. It removes the requirement of at least 2 labels in the label list. This requirement wasn't really necessary, and there are cases where checking just one label with multi_label=True is useful.
  3. It updates the docstrings and adds an example.
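A minimal sketch of what hypothesis_template does, assuming the endpoint follows the usual NLI-based zero-shot approach of the transformers pipeline: each candidate label is substituted into the template's {} placeholder, and each resulting hypothesis is paired with the input text and scored by the NLI model. The helper name build_hypotheses and the default template string are illustrative assumptions, not huggingface_hub code.

```python
# Illustrative sketch (assumption): how an NLI-based zero-shot classifier
# turns candidate labels into hypotheses. Not huggingface_hub's own code.

# Assumed default template, borrowed from the transformers pipeline.
DEFAULT_TEMPLATE = "This example is {}."


def build_hypotheses(labels, hypothesis_template=DEFAULT_TEMPLATE):
    """Return one hypothesis string per label by filling the template."""
    if not labels:
        # After this PR a single label is fine, but zero labels is not.
        raise ValueError("At least one candidate label is required.")
    # A template without a '{}' placeholder would yield the same hypothesis
    # for every label, so reject it.
    if hypothesis_template.format(labels[0]) == hypothesis_template:
        raise ValueError("hypothesis_template must contain a '{}' placeholder.")
    return [hypothesis_template.format(label) for label in labels]


premise = "I really like our dinner and I'm very happy. I don't like the weather though."
template = "This text is {} towards the weather"
for hypothesis in build_hypotheses(["positive", "negative"], template):
    # Each (premise, hypothesis) pair would be scored separately by the NLI model.
    print(premise, "->", hypothesis)
```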

Tested it with an HF Inference Endpoint running the model deberta-v3-large-zeroshot-v2.

Example using hypothesis_template:

from huggingface_hub import InferenceClient

client = InferenceClient()

output = client.zero_shot_classification(
    model="https://h9qyt7jenlitt7j6.us-east-1.aws.endpoints.huggingface.cloud",
    text="I really like our dinner and I'm very happy. I don't like the weather though.",
    labels=["positive", "negative", "pessimistic", "optimistic"],
    multi_label=True,
    hypothesis_template="This text is {} towards the weather"
)

# output:
[ZeroShotClassificationOutputElement(label='negative', score=0.9231801629066467),
 ZeroShotClassificationOutputElement(label='pessimistic', score=0.8760990500450134),
 ZeroShotClassificationOutputElement(label='optimistic', score=0.0008674879791215062),
 ZeroShotClassificationOutputElement(label='positive', score=0.0005250611575320363)]
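Because multi_label=True scores each label independently, the scores in the output above do not sum to 1 (negative and pessimistic are both high), so a per-label threshold is a natural way to consume them. The sketch below uses a hypothetical stand-in dataclass in place of ZeroShotClassificationOutputElement so it runs without a network call; the labels_above helper and the 0.5 threshold are assumptions for illustration.

```python
from dataclasses import dataclass


# Hypothetical stand-in for huggingface_hub's ZeroShotClassificationOutputElement,
# so the sketch runs offline; the real element also exposes `label` and `score`.
@dataclass
class OutputElement:
    label: str
    score: float


def labels_above(output, threshold=0.5):
    """Keep labels whose independent (multi_label=True) score reaches the threshold."""
    return [element.label for element in output if element.score >= threshold]


# Scores copied from the hypothesis_template example output.
output = [
    OutputElement("negative", 0.9231801629066467),
    OutputElement("pessimistic", 0.8760990500450134),
    OutputElement("optimistic", 0.0008674879791215062),
    OutputElement("positive", 0.0005250611575320363),
]
print(labels_above(output))  # ['negative', 'pessimistic']
```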

Example with a single label:

from huggingface_hub import InferenceClient

client = InferenceClient()

output = client.zero_shot_classification(
    model="https://h9qyt7jenlitt7j6.us-east-1.aws.endpoints.huggingface.cloud",
    text="I really like our dinner and I'm very happy. I don't like the weather though.",
    labels=["positive"],
    multi_label=True,
    hypothesis_template="This text is {} towards the weather"
)

# output:
[ZeroShotClassificationOutputElement(label='positive', score=0.0001598795352037996)]
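Why a single label is only meaningful with multi_label=True: in the usual NLI scoring scheme (assumed here to match what the endpoint runs), multi_label=True scores each label with a softmax over its own contradiction and entailment logits, while multi_label=False applies a softmax to the entailment logits across all labels. With one label, that cross-label softmax always returns 1.0. The logits below are made-up numbers, not model outputs, and the neutral class is ignored.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def multi_label_scores(pairs):
    """pairs: (contradiction_logit, entailment_logit) per label.
    Each label is scored on its own: softmax over [contradiction, entailment]."""
    return [softmax([c, e])[1] for c, e in pairs]


def single_label_scores(pairs):
    """Labels compete: softmax over the entailment logits across all labels."""
    return softmax([e for _, e in pairs])


# Hypothetical logits for ONE label whose hypothesis is clearly not entailed.
pairs = [(4.0, -4.0)]
print(multi_label_scores(pairs))   # close to 0: the label does not apply
print(single_label_scores(pairs))  # [1.0]: trivially certain, uninformative
```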

HuggingFaceDocBuilderDev commented 3 months ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

MoritzLaurer commented 3 months ago

Thanks for the review, @Wauplin! I've applied your suggestions and ran make style. One odd thing: running python utils/generate_async_inference_client.py --update removes the references to async from the async client docstring (see this commit). I first added the async references manually, but then the checks no longer passed. Just flagging this in case it's unintended.

MoritzLaurer commented 3 months ago

Great, thanks a lot for your review, @Wauplin!