Max-Jesch closed this issue 2 years ago
I am glad you like it. You can totally do this: just modify this line to get the value of the logits:
https://github.com/oliverguhr/german-sentiment-lib/blob/master/germansentiment/sentimentmodel.py#L32
If you decide to change the code, I would appreciate a pull request :)
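For context, the logits are unnormalized scores, so they are usually passed through a softmax before being read as a "confidence". Below is a minimal sketch in plain Python of that step. The logit values are hypothetical and are not taken from the model, and `softmax` here is an illustrative helper, not part of the library.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Convert raw logits into probabilities that sum to 1."""
    # Subtracting the maximum keeps exp() from overflowing; the result is unchanged.
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for a single text (one value per class).
example_logits = [4.2, 0.5, -3.8]
probabilities = softmax(example_logits)

# The largest probability can serve as a simple confidence score.
confidence = max(probabilities)
print(probabilities, confidence)
```

Note that a softmax score is only a rough confidence measure: it says how strongly the model prefers one class over the others, not how likely the prediction is to be correct.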
I needed the same thing, so I overrode the predict method in a subclass:
from typing import Dict, List, Tuple

import torch
from germansentiment import SentimentModel


class SentimentModel(SentimentModel):
    def __init__(self):
        super().__init__()

    def predict_sentiment_proba(self, texts: List[str]) -> Tuple[List[List[float]], Dict[int, str]]:
        """Return per-class probabilities for each text, plus the id-to-label mapping."""
        texts = [self.clean_text(text) for text in texts]
        # add_special_tokens takes care of adding [CLS], [SEP], <s>... tokens in the right way for each model.
        # truncation=True limits the number of tokens to the model's limit (512).
        encoded = self.tokenizer.batch_encode_plus(
            texts, padding=True, add_special_tokens=True, truncation=True, return_tensors="pt"
        )
        encoded = encoded.to(self.device)
        with torch.no_grad():
            logits = self.model(**encoded)
        # Instead of taking the argmax over the classes, apply a softmax to get probabilities.
        probabilities = torch.softmax(logits[0], dim=1)
        return probabilities.tolist(), self.model.config.id2label
I added an API feature that does this in version 1.1.0:
from germansentiment import SentimentModel
model = SentimentModel()
classes, probabilities = model.predict_sentiment(["das ist super"], output_probabilities = True)
print(classes, probabilities)
['positive'] [[['positive', 0.9761366844177246], ['negative', 0.023540444672107697], ['neutral', 0.00032294404809363186]]]
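To turn that output into a single confidence value per text, one option is to pick the highest-scoring label from each inner list. This is a small sketch that assumes the nested `[label, probability]` structure printed above; `top_confidence` is a hypothetical helper, not part of the library.

```python
from typing import List, Sequence, Tuple


def top_confidence(probabilities: Sequence[Sequence[Sequence]]) -> List[Tuple[str, float]]:
    """For each text, return the most likely label and its probability."""
    results = []
    for per_text in probabilities:
        # Each entry is a [label, probability] pair; pick the pair with the highest probability.
        label, prob = max(per_text, key=lambda pair: pair[1])
        results.append((label, prob))
    return results


# The probabilities printed in the example above.
probabilities = [[["positive", 0.9761366844177246],
                  ["negative", 0.023540444672107697],
                  ["neutral", 0.00032294404809363186]]]

print(top_confidence(probabilities))
# [('positive', 0.9761366844177246)]
```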
Pretty awesome project. Thanks a lot for sharing.
One feature that would be really valuable to me is some sort of "confidence" for the predictions. Do you think that would be tricky to do? I would be happy to help if you think that makes sense.