ai-forever / ru-gpts

Russian GPT-3 models.
Apache License 2.0

Compute perplexity for text classification #79

Closed vladimirsvsv77 closed 2 years ago

vladimirsvsv77 commented 2 years ago

Hi! I am trying to use the model to classify texts. I found an article about this here.

It describes the get_answer(sentence1: str, sentence2: str) method, which relies on get_perp_num, a method that returns perplexity. Did I understand correctly that the implementation of get_perp_num might look something like this:

import math

import torch

def calculate_perplexity(sentence, model, tokenizer):
    encodings = tokenizer(sentence, return_tensors='pt')
    input_ids = encodings.input_ids.to(model.device)
    with torch.no_grad():
        # Passing input_ids as labels makes the model return the
        # mean token-level cross-entropy (it shifts labels internally)
        outputs = model(input_ids=input_ids, labels=input_ids)
    loss = outputs.loss
    # Perplexity is the exponential of the mean loss
    return math.exp(loss.item())
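
For completeness, this is how I load the model and call the function (sberbank-ai/rugpt3small_based_on_gpt2 is just an example checkpoint; any of the ru-gpts models should work the same way):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Example checkpoint only; substitute any ru-gpts model from the hub
model_name = 'sberbank-ai/rugpt3small_based_on_gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()

print(calculate_perplexity('Мама мыла раму.', model, tokenizer))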
TatianaShavrina commented 2 years ago

Hey @vladimirsvsv77,

Absolutely. I checked the code I was using during the experiments:

from transformers import GPT2LMHeadModel, PreTrainedTokenizer

import torch
import math

def calc_ppl(phrase: str,
             tokenizer: PreTrainedTokenizer,
             model: GPT2LMHeadModel) -> float:

    # Wrap the token ids in a tensor with a batch dimension
    input_ids = torch.tensor([tokenizer(phrase)['input_ids']])
    with torch.no_grad():
        loss = model(input_ids=input_ids, labels=input_ids).loss
    # loss is the mean cross-entropy, so exp(loss) is the perplexity
    perplexity = math.exp(loss.item())
    return perplexity
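
A minimal get_answer built on top of calc_ppl might then just compare the two perplexities, treating the lower-perplexity candidate as the more plausible one (a sketch of the idea, not the article's exact code):

def get_answer(sentence1: str, sentence2: str,
               tokenizer: PreTrainedTokenizer,
               model: GPT2LMHeadModel) -> str:
    # Sketch: return whichever candidate the language model
    # scores as more fluent, i.e. the one with lower perplexity
    if calc_ppl(sentence1, tokenizer, model) < calc_ppl(sentence2, tokenizer, model):
        return sentence1
    return sentence2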
vladimirsvsv77 commented 2 years ago

Got it, thanks!