huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

pipeline 'text-classification' in >=4.40.0 throwing TypeError: Got unsupported ScalarType BFloat16 #30542

Closed: derekelewis closed this issue 1 month ago

derekelewis commented 4 months ago

System Info

Who can help?

No response

Information

Tasks

Reproduction

The error does not occur in 4.39.3; it happens in >=4.40.0 and on main. It appears to be related to PR #30518.

Test code is below (please ignore that a pre-trained sequence-classification model is not being used):

import torch

from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

model_id = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

# Loading the model in bfloat16 is what triggers the error during postprocessing.
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    label2id={"LABEL_0": 0, "LABEL_1": 1},
    num_labels=2,
)

classification_pipeline = pipeline(
    "text-classification", model=model, tokenizer=tokenizer, return_all_scores=True
)

predictions = classification_pipeline("test")

Traceback:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 1
----> 1 predictions = pipeline("test")

File ~/miniconda3/envs/mlenv/lib/python3.10/site-packages/transformers/pipelines/text_classification.py:156, in TextClassificationPipeline.__call__(self, inputs, **kwargs)
    122 """
    123 Classify the text(s) given as inputs.
    124 
   (...)
    153     If `top_k` is used, one such dictionary is returned per label.
    154 """
    155 inputs = (inputs,)
--> 156 result = super().__call__(*inputs, **kwargs)
    157 # TODO try and retrieve it in a nicer way from _sanitize_parameters.
    158 _legacy = "top_k" not in kwargs

File ~/miniconda3/envs/mlenv/lib/python3.10/site-packages/transformers/pipelines/base.py:1242, in Pipeline.__call__(self, inputs, num_workers, batch_size, *args, **kwargs)
   1234     return next(
   1235         iter(
   1236             self.get_iterator(
   (...)
   1239         )
   1240     )
   1241 else:
-> 1242     return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)

File ~/miniconda3/envs/mlenv/lib/python3.10/site-packages/transformers/pipelines/base.py:1250, in Pipeline.run_single(self, inputs, preprocess_params, forward_params, postprocess_params)
   1248 model_inputs = self.preprocess(inputs, **preprocess_params)
   1249 model_outputs = self.forward(model_inputs, **forward_params)
-> 1250 outputs = self.postprocess(model_outputs, **postprocess_params)
   1251 return outputs

File ~/miniconda3/envs/mlenv/lib/python3.10/site-packages/transformers/pipelines/text_classification.py:205, in TextClassificationPipeline.postprocess(self, model_outputs, function_to_apply, top_k, _legacy)
    202         function_to_apply = ClassificationFunction.NONE
    204 outputs = model_outputs["logits"][0]
--> 205 outputs = outputs.numpy()
    207 if function_to_apply == ClassificationFunction.SIGMOID:
    208     scores = sigmoid(outputs)

TypeError: Got unsupported ScalarType BFloat16
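
For reference, the failure can be reproduced with plain PyTorch: Tensor.numpy() cannot convert bfloat16 because NumPy has no bfloat16 dtype, while casting to float32 first works. A minimal sketch outside the pipeline:

import torch

t = torch.zeros(2, dtype=torch.bfloat16)

try:
    t.numpy()  # NumPy has no bfloat16 dtype, so this raises
except TypeError as err:
    print(err)  # Got unsupported ScalarType BFloat16

print(t.float().numpy())  # upcasting to float32 first works: [0. 0.]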

Expected behavior

predictions = classification_pipeline("test") should return predictions.
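
In the meantime, a possible workaround (my assumption, since only the bfloat16-to-NumPy conversion fails) is to load the classification model in float32 instead:

model_fp32 = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float32,  # float32 logits convert to NumPy without error
    label2id={"LABEL_0": 0, "LABEL_1": 1},
    num_labels=2,
)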

amyeroberts commented 4 months ago

cc @ArthurZucker

iseesaw commented 3 months ago

same question!

ArthurZucker commented 3 months ago

This is related to #28109. But yes, it is a bit weird that we go to numpy() there. Do you want to open a PR for this?
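
A sketch of the kind of change such a PR could make in TextClassificationPipeline.postprocess (illustrative only, not necessarily the exact fix that gets merged): upcast the logits before the NumPy conversion.

outputs = model_outputs["logits"][0]
# Hypothetical change: upcast bfloat16/float16 logits so .numpy() succeeds
outputs = outputs.float().numpy()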

ccmilne commented 3 months ago

Any update on this? Getting the same result when using the pipeline class for question answering

ArthurZucker commented 3 months ago

A fix was merged and will be included in the release

mnoukhov commented 2 months ago

@ArthurZucker which PR fixes this issue?

AIR-hl commented 2 months ago

same problem

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

ArthurZucker commented 1 month ago

#30999 is the fix!