lightonai / RITA

RITA is a family of autoregressive protein models, developed by LightOn in collaboration with the OATML group at Oxford and the Debora Marks Lab at Harvard.

How to obtain perplexity evaluation datasets? #11


LGH1gh commented 1 year ago

Dear Author,

Thanks for releasing RITA for protein generation! However, I wonder how I can obtain the perplexity evaluation datasets used in your paper, and how to calculate perplexity. I hope for your suggestions. Thanks in advance!

Detopall commented 2 months ago

I can't really help you with obtaining the perplexity evaluation datasets used in their paper, but you can use the following code to calculate perplexity:

```python
import math
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model = AutoModelForCausalLM.from_pretrained("lightonai/RITA_s", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("lightonai/RITA_s")

# Sample a couple of sequences from the model, then score each one.
rita_gen = pipeline('text-generation', model=model, tokenizer=tokenizer)
sequences = rita_gen("MAB", max_length=200, do_sample=True, top_k=950, repetition_penalty=1.2,
                     num_return_sequences=2, eos_token_id=2)

def calculatePerplexity(sequence, model, tokenizer):
    """Perplexity = exp(mean per-token negative log-likelihood of the sequence)."""
    input_ids = torch.tensor(tokenizer.encode(sequence)).unsqueeze(0)
    input_ids = input_ids.to(model.device)

    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    loss, logits = outputs[:2]

    return math.exp(loss.item())

for seq in sequences:
    # Strip the spaces the tokenizer inserts and score the cleaned sequence.
    sequence = seq['generated_text'].replace(' ', '')
    print(f"seq: {sequence}")
    ppl = calculatePerplexity(sequence, model, tokenizer)
    print(f"Perplexity: {ppl}\n")
```

With these results:

```
seq: MABVVGTALYPGSDRFDGEYEVDIVIDTDGARYVLPVINTITHVKQGTSTRHPLGKAGQARKYATMHTGNLVLHLFDKGHTGVSIHGTSIDERIFGADGRVIAEAQGSGDMRHYGISPNRVAVCVARPFGGEGFSVPLSIHALGNETGVQTTGSGDVSTTSAVEGPAQEQMGFLDHTLSYASSTILTYRTQVTTGLGGAR
Perplexity: 132566.77587907546

seq: MABPVVTREPGVYFLAPRVSKFYEIIPWWNEMYVIECSIVSAAAGAPAVTPIQIRAPDVDIMSQVTSTAGMTAFVKVKRSRVIKMYQRVEPVERLHALVGGASILLDASLPQAALVTIEGGDIFEVFHGTEGLLAIIDGAIQQGLFSYKM
Perplexity: 127686.55561821107
```

The lower the perplexity score, the better. The lower perplexity of the second sequence means the model assigns it a higher average per-residue likelihood, so it is likely a better-quality sequence than the first one according to the language model.
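For intuition, perplexity here is just the exponential of the mean per-token negative log-likelihood, so the two scores above differ only slightly in terms of average per-residue loss. A quick check using the numbers reported above:

```python
import math

# Convert the reported perplexities back to mean per-token negative log-likelihood.
for ppl in (132566.77587907546, 127686.55561821107):
    print(f"perplexity={ppl:.2f} -> mean NLL={math.log(ppl):.4f}")
```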

Hope this helps.