oliverguhr / german-sentiment-lib

An easy-to-use Python package for deep-learning-based German sentiment classification.
https://pypi.org/project/germansentiment/
MIT License
58 stars, 7 forks

Feature Request: Return confidence of prediction #9

Closed: Max-Jesch closed this issue 2 years ago

Max-Jesch commented 3 years ago

Pretty awesome project. Thanks a lot for sharing.

One feature that would be really valuable to me is some sort of "confidence" score for the predictions. Do you think that would be tricky to do? I'd be happy to help if you think it makes sense.

oliverguhr commented 3 years ago

I am glad you like it. You can definitely do this: just modify this line to get the logit values:

https://github.com/oliverguhr/german-sentiment-lib/blob/master/germansentiment/sentimentmodel.py#L32

If you decide to change the code, I would appreciate a pull request :)
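Note that logits are unnormalized scores, not confidences on their own; a common way to turn them into a probability distribution over the labels is a softmax. The sketch below shows that step in plain Python. The logit values and the label order (positive, negative, neutral) are made-up assumptions for illustration, not actual model output:

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Convert raw logits into probabilities that sum to 1."""
    # Subtracting the maximum logit avoids overflow in exp() and does not change the result.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for one text, in the assumed order (positive, negative, neutral).
logits = [4.0, 0.3, -3.0]
probs = softmax(logits)
print([round(p, 3) for p in probs])
```

The largest logit always gets the largest probability, so the predicted class is unchanged; the softmax only adds a comparable 0-to-1 score per label.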

bpfrd commented 2 years ago

I needed the same thing, so I subclassed SentimentModel and added a method that returns the class probabilities:

from typing import Dict, List, Tuple

import torch
from germansentiment import SentimentModel


class SentimentModelWithProba(SentimentModel):
    """Extends SentimentModel with a method that returns class probabilities."""

    def predict_sentiment_proba(self, texts: List[str]) -> Tuple[List[List[float]], Dict[int, str]]:
        texts = [self.clean_text(text) for text in texts]
        # add_special_tokens=True adds the [CLS], [SEP], <s>, ... tokens in the right way for each model.
        # truncation=True limits the number of tokens to the model's maximum (512).
        encoded = self.tokenizer.batch_encode_plus(
            texts, padding=True, add_special_tokens=True, truncation=True, return_tensors="pt"
        )
        encoded = encoded.to(self.device)
        with torch.no_grad():
            logits = self.model(**encoded).logits  # shape: (batch_size, num_labels)

        # Softmax over the label dimension turns each row of logits into probabilities summing to 1.
        probabilities = torch.softmax(logits, dim=1).tolist()
        # id2label maps each column index of `probabilities` to its label name.
        return probabilities, self.model.config.id2label
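The method above returns the probabilities and id2label separately, so the caller still has to match each column to its label. A minimal sketch of that pairing, using a hypothetical helper pair_with_labels and dummy numbers in place of real model output. The id2label mapping shown here is an assumption; read the actual one from model.config.id2label:

```python
from typing import Dict, List, Tuple


def pair_with_labels(
    probabilities: List[List[float]], id2label: Dict[int, str]
) -> List[List[Tuple[str, float]]]:
    """Attach label names to each row of probabilities, highest probability first."""
    paired = []
    for row in probabilities:
        # Column i of each row corresponds to id2label[i].
        labelled = [(id2label[i], p) for i, p in enumerate(row)]
        labelled.sort(key=lambda item: item[1], reverse=True)
        paired.append(labelled)
    return paired


# Dummy values standing in for the output of predict_sentiment_proba (one row per text).
probs = [[0.10, 0.85, 0.05], [0.70, 0.20, 0.10]]
id2label = {0: "positive", 1: "negative", 2: "neutral"}  # assumed mapping
for row in pair_with_labels(probs, id2label):
    print(row)
```

The first entry of each sorted row is then the predicted label together with its probability.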
oliverguhr commented 2 years ago

I added this as an API feature in version 1.1.0:

from germansentiment import SentimentModel

model = SentimentModel()

classes, probabilities = model.predict_sentiment(["das ist super"], output_probabilities=True)
print(classes, probabilities)
# Output:
# ['positive'] [[['positive', 0.9761366844177246], ['negative', 0.023540444672107697], ['neutral', 0.00032294404809363186]]]
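Each inner list pairs a label with its probability, so the probability of the predicted class can serve as a simple confidence score. A minimal sketch that discards low-confidence predictions, assuming the nested output format shown above. The helper name confident_predictions and the 0.8 threshold are hypothetical choices, not part of the library:

```python
from typing import List, Optional, Sequence


def confident_predictions(
    classes: List[str],
    probabilities: List[List[Sequence]],
    threshold: float = 0.8,
) -> List[Optional[str]]:
    """Return each predicted class, or None when its probability is below the threshold."""
    results = []
    for predicted, label_probs in zip(classes, probabilities):
        # label_probs is a list of [label, probability] pairs; look up the predicted class.
        confidence = dict(label_probs)[predicted]
        results.append(predicted if confidence >= threshold else None)
    return results


# Values copied from the example output above.
classes = ["positive"]
probabilities = [[["positive", 0.9761366844177246], ["negative", 0.023540444672107697], ["neutral", 0.00032294404809363186]]]
print(confident_predictions(classes, probabilities))
```

What threshold is sensible depends on the application; returning None simply marks predictions that may deserve a manual check.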