huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

KeyError: 'logits' #17880

Closed kkavyashankar0009 closed 2 years ago

kkavyashankar0009 commented 2 years ago

System Info

- `transformers` version: 4.16.2
- Platform: Linux-5.13.0-48-generic-x86_64-with-glibc2.31
- Python version: 3.9.7
- PyTorch version (GPU?): 1.9.1+cu111 (True)
- Tensorflow version (GPU?): 2.4.1 (False)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

Who can help?

@Narsil

Reproduction

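from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, pipeline

bert_name = 'bert-base-cased'
bert = AutoModel.from_pretrained(bert_name)
tokenizer = AutoTokenizer.from_pretrained(bert_name)

classifier = pipeline("zero-shot-classification", model=bert, tokenizer=tokenizer)

# data_loader, c, label (the candidate labels) and Ground_Truth are
# defined elsewhere in the original script (not shown).
for d in tqdm(data_loader):
    text = d['text']
    true_label = d["label"]

    for i in range(len(text)):
        tl = c.index(true_label[i])
        Ground_Truth.append(tl)
        output = classifier(text[i], label)
        print('output', output)
        high_score = max(output['scores'])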

Error:
  File "/home/kshankar/Desktop/Project/Zero_Shot_updated/Fine-tuning/BBC_distilbert-base-uncased-finetuned-sst-2-english.py", line 187, in eval_model
    output=classifier(text[i],label)
  File "/home/kshankar/miniconda3/lib/python3.9/site-packages/transformers/pipelines/zero_shot_classification.py", line 182, in __call__
    return super().__call__(sequences, **kwargs)
  File "/home/kshankar/miniconda3/lib/python3.9/site-packages/transformers/pipelines/base.py", line 1006, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/home/kshankar/miniconda3/lib/python3.9/site-packages/transformers/pipelines/base.py", line 1030, in run_single
    outputs = self.postprocess(all_outputs, **postprocess_params)
  File "/home/kshankar/miniconda3/lib/python3.9/site-packages/transformers/pipelines/zero_shot_classification.py", line 214, in postprocess
    logits = np.concatenate([output["logits"].numpy() for output in model_outputs])
  File "/home/kshankar/miniconda3/lib/python3.9/site-packages/transformers/pipelines/zero_shot_classification.py", line 214, in <listcomp>
    logits = np.concatenate([output["logits"].numpy() for output in model_outputs])
KeyError: 'logits'

Expected behavior

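The pipeline should return classification scores; instead, `logits` is looked up in the model outputs without ever having been produced.
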
NielsRogge commented 2 years ago

Hi,

You're loading the pipeline with a BertModel, which doesn't include a head on top (like a sequence classification head for instance). Hence, no logits are computed.

The zero-shot classification pipeline makes use of sequence classifiers fine-tuned on an NLI task (natural language inference). Hence, you'll need to provide an xxxForSequenceClassification model fine-tuned on such a dataset.
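To see why the missing head matters: the zero-shot pipeline runs the NLI classifier once per candidate label on the pair (input text, hypothesis built from the label, by default with the template "This example is {}."), and its postprocessing reads each pair's `logits`. A base `AutoModel` returns hidden states only, so that lookup raises the `KeyError` above. The pure-Python sketch below mimics only the default single-label scoring step (two or more candidate labels, `multi_label=False`). It assumes an MNLI-style `label2id` such as `{"contradiction": 0, "neutral": 1, "entailment": 2}` and uses made-up logits; the function names are hypothetical, not transformers APIs.

```python
import math


def find_entailment_id(label2id):
    """Return the index of the first label starting with "entail", or -1 if none."""
    for label, idx in label2id.items():
        if label.lower().startswith("entail"):
            return idx
    return -1


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def zero_shot_scores(pair_logits, candidate_labels, label2id):
    """Rank candidate labels from per-pair NLI logits (single-label mode only).

    pair_logits[i] holds the NLI logits (ordered as in label2id) for the pair
    (text, hypothesis built from candidate_labels[i]). A headless model never
    produces these logits, which is what KeyError: 'logits' reports.
    """
    entail_id = find_entailment_id(label2id)
    if entail_id == -1:
        raise ValueError("no entailment label in label2id; is this an NLI classifier?")
    # Softmax across candidate labels, using only each pair's entailment logit.
    scores = softmax([logits[entail_id] for logits in pair_logits])
    ranked = sorted(zip(candidate_labels, scores), key=lambda pair: pair[1], reverse=True)
    return {"labels": [lab for lab, _ in ranked], "scores": [s for _, s in ranked]}


# Usage with made-up logits for two candidate labels (assumed MNLI-style config).
label2id = {"contradiction": 0, "neutral": 1, "entailment": 2}
result = zero_shot_scores([[2.0, 0.5, -1.0], [-1.5, 0.0, 3.0]], ["Science", "politics"], label2id)
print(result)
```

The `-1` fallback shows why a checkpoint without an entailment label (for example one whose config only has `LABEL_0`/`LABEL_1`) cannot serve as a zero-shot classifier even if it does output logits.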

github-actions[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

kkavyashankar0009 commented 2 years ago

Thank you for the response!

I get the same error even when I use a pre-trained zero-shot classification model from the Hugging Face Hub.

For example:

bert_name = 'facebook/bart-large-mnli'
model = AutoModel.from_pretrained(bert_name)
tokenizer = AutoTokenizer.from_pretrained(bert_name)
classifier = pipeline("zero-shot-classification", model=model, tokenizer=tokenizer)

Narsil commented 2 years ago

You need to replace AutoModel with AutoModelForSequenceClassification and use a model that supports AutoModelForSequenceClassification.

Or use directly

pipe = pipeline(model="facebook/bart-large-mnli")
print(pipe("Is this ok?", candidate_labels=["Science", "politics"]))

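If you prefer to load the model and tokenizer explicitly (as in the original script) rather than passing a model name to `pipeline`, the AutoModelForSequenceClassification route described above can be sketched as follows. This assumes transformers is installed and `facebook/bart-large-mnli` can be downloaded; `build_zero_shot_classifier` is a hypothetical helper, and the entailment-label guard is an extra check added for illustration, not something the pipeline requires.

```python
def build_zero_shot_classifier(model_name="facebook/bart-large-mnli"):
    """Hypothetical helper: load an NLI checkpoint *with* its classification head."""
    # Imported inside the function so the sketch can be defined without transformers.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # AutoModelForSequenceClassification keeps the NLI head, so outputs include "logits";
    # plain AutoModel drops the head, which is what triggers KeyError: 'logits'.
    model = AutoModelForSequenceClassification.from_pretrained(model_name)

    # Fail early on checkpoints that are not NLI classifiers (e.g. bert-base-cased,
    # whose sequence-classification head would be freshly initialized).
    if not any(label.lower().startswith("entail") for label in model.config.label2id):
        raise ValueError(f"{model_name} has no entailment label; pick an NLI checkpoint")

    return pipeline("zero-shot-classification", model=model, tokenizer=tokenizer)


# Example usage (downloads the checkpoint, so it is left commented out here):
# classifier = build_zero_shot_classifier()
# print(classifier("Is this ok?", candidate_labels=["Science", "politics"]))
```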
kkavyashankar0009 commented 2 years ago

It's working. Thanks a lot.

akashAD98 commented 7 months ago

I'm getting the same error with the bart-large-mnli model:

#!pip install torch==2.1.2
#!pip install --upgrade --upgrade-strategy eager optimum[onnxruntime]

!optimum-cli export onnx  --task zero-shot-classification --model facebook/bart-large-mnli bart-large-mnli_onnx_zs_model/

from optimum.onnxruntime import ORTModelForQuestionAnswering
from transformers import AutoTokenizer,pipeline
# load the exported ONNX model and its tokenizer
tokenizer = AutoTokenizer.from_pretrained("bart-large-mnli_onnx_zs_model")
model = ORTModelForQuestionAnswering.from_pretrained("bart-large-mnli_onnx_zs_model")

onnx_z0 = pipeline("zero-shot-classification", model=model, tokenizer=tokenizer)

sequence_to_classify = "Who are you voting for in 2020?"
candidate_labels = ["Europe", "public health", "politics", "elections"]
pred = onnx_z0(sequence_to_classify, candidate_labels)
pred

Error:

KeyError                                  Traceback (most recent call last)
in <cell line: 10>()
      8 sequence_to_classify = "Who are you voting for in 2020?"
      9 candidate_labels = ["Europe", "public health", "politics", "elections"]
---> 10 pred = onnx_z0(sequence_to_classify, candidate_labels)
     11 pred

7 frames
/usr/local/lib/python3.10/dist-packages/optimum/onnxruntime/modeling_ort.py in forward(self, input_ids, attention_mask, token_type_ids, **kwargs)
   1258         outputs = self.model.run(None, onnx_inputs)
   1259 
-> 1260         start_logits = outputs[self.output_names["start_logits"]]
   1261         end_logits = outputs[self.output_names["end_logits"]]
   1262         if use_torch:

KeyError: 'start_logits'
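The thread contains no reply to this last report. Judging from the traceback, the likely cause mirrors the original issue: `ORTModelForQuestionAnswering` expects a question-answering graph with `start_logits`/`end_logits` outputs, whereas a model exported with `--task zero-shot-classification` is a sequence classifier, and the zero-shot pipeline reads its `logits`. The sketch below assumes optimum[onnxruntime] is installed, that the export above succeeded, and that the exported graph is saved as `model.onnx` inside the output directory (an assumption about the export layout). Both helper names are hypothetical.

```python
from pathlib import Path


def onnx_output_names(model_dir, filename="model.onnx"):
    """List the output names of an exported ONNX graph (filename is an assumption)."""
    import onnxruntime  # imported lazily so the sketch can be defined without it

    session = onnxruntime.InferenceSession(
        str(Path(model_dir) / filename), providers=["CPUExecutionProvider"]
    )
    return [node.name for node in session.get_outputs()]


def build_ort_zero_shot_classifier(model_dir="bart-large-mnli_onnx_zs_model"):
    """Wrap the exported model in the sequence-classification class, not the QA class."""
    from optimum.onnxruntime import ORTModelForSequenceClassification
    from transformers import AutoTokenizer, pipeline

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    # Sequence-classification wrapper: its outputs expose "logits", which the
    # zero-shot pipeline's postprocessing reads.
    model = ORTModelForSequenceClassification.from_pretrained(model_dir)
    return pipeline("zero-shot-classification", model=model, tokenizer=tokenizer)


# Example usage (requires the exported directory, so left commented out here):
# print(onnx_output_names("bart-large-mnli_onnx_zs_model"))
# onnx_z0 = build_ort_zero_shot_classifier()
# print(onnx_z0("Who are you voting for in 2020?", ["Europe", "public health", "politics", "elections"]))
```

If the export is indeed a sequence classifier, `onnx_output_names` should list `logits` rather than `start_logits`, which would confirm that the model class, not the export, is the mismatch.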