elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)
Apache License 2.0

Downcase label before checking equality #327

Closed: preciz closed this pull request 4 months ago.

preciz commented 4 months ago

Some models define their labels in uppercase. For these models, building the zero-shot classification serving raises (ArgumentError) expected model specification to include "entailment" label in :id_to_label, because the label comparison is case-sensitive.

One example is this model's config: https://huggingface.co/typeform/distilbert-base-uncased-mnli/blob/main/config.json

I tested this commit, and with it I can use the model mentioned above.
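The idea behind the change can be illustrated with a small standalone sketch: downcase both the stored label and the expected label before comparing them, so that "ENTAILMENT", "Entailment" and "entailment" all match. The module `LabelLookup` and its functions are hypothetical names for illustration only; this is not Bumblebee's internal code, and it assumes `id_to_label` is a map from integer ids to label strings.

```elixir
defmodule LabelLookup do
  # Hypothetical helper (not part of Bumblebee): returns the id whose label
  # matches `label` case-insensitively, or nil if no label matches.
  def find_label_id(id_to_label, label) do
    target = String.downcase(label)

    # Enum.find_value/2 iterates the map as {id, name} tuples and returns
    # the first non-nil value produced by the function.
    Enum.find_value(id_to_label, fn {id, name} ->
      if String.downcase(name) == target, do: id
    end)
  end

  # Same lookup, but raises an ArgumentError shaped like the one quoted above
  # when no matching label exists.
  def fetch_label_id!(id_to_label, label) do
    case find_label_id(id_to_label, label) do
      nil ->
        raise ArgumentError,
              "expected model specification to include #{inspect(label)} label in :id_to_label"

      id ->
        id
    end
  end
end
```

For example, `LabelLookup.fetch_label_id!(%{0 => "CONTRADICTION", 1 => "NEUTRAL", 2 => "ENTAILMENT"}, "entailment")` returns `2`, whereas an exact, case-sensitive comparison would find no match and raise.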

The code I used to test this commit:

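# Use EXLA as the default Nx backend
Nx.default_backend({EXLA.Backend, []})

# The model linked above, whose config defines its labels in uppercase
hf_model = "typeform/distilbert-base-uncased-mnli"

{:ok, model} = Bumblebee.load_model({:hf, hf_model})
# Tokenizer loaded from the base DistilBERT uncased checkpoint
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "distilbert/distilbert-base-uncased"})

# Building the serving is where the ArgumentError was raised before this change
serving = Bumblebee.Text.zero_shot_classification(model, tokenizer, ["Is clothes", "Is food"])

Nx.Serving.run(serving, "Steak is tasty")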

Thank you for reviewing.