Closed SuMeng123 closed 1 year ago
from datasets import load_metric
import json
import sys

with open(sys.argv[1], 'r') as f:
    input_texts = json.load(f)

def clean(s):
    # Fall back to a single space so empty strings don't break the metric
    if s:
        return s
    return ' '

perplexity = load_metric("./perplexity.py")
results = perplexity.compute(
    input_texts=[clean(item[1].replace('\n', ' ').strip()) for item in input_texts],
    model_id='gpt2-medium',
    batch_size=8,
)
print('Perplexity:', round(results["mean_perplexity"], 2))
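For reference, the `mean_perplexity` printed above is the average over texts of each text's perplexity, where one text's perplexity is the exponential of its mean token negative log-likelihood. A minimal sketch of that aggregation, using made-up token log-probabilities rather than the metric's real model outputs:

```python
import math

def text_perplexity(token_logprobs):
    # Perplexity of one text: exp of the mean negative log-likelihood
    nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(nll)

def mean_perplexity(all_logprobs):
    # Average per-text perplexities, as the metric's "mean_perplexity" does
    ppls = [text_perplexity(lp) for lp in all_logprobs]
    return sum(ppls) / len(ppls)

# Hypothetical token log-probabilities for two short texts
batch = [[-1.0, -2.0, -3.0], [-0.5, -0.5]]
print(round(mean_perplexity(batch), 2))  # → 4.52
```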
Thank you for your response. It seems that "./perplexity.py" is missing.
Is the source of perplexity.py from this website?
https://github.com/huggingface/datasets/blob/main/metrics/perplexity/perplexity.py
Yes.
I ran the task "Train Alpaca with RRHF on the Helpful and Harmless dataset," and the result showed that the reward metric of Alpaca-PPO was close to -1.02, but the perplexity metric was very large.
What does "input_texts" specifically refer to?
[[" \n\nHuman: How do I make granola?\n\nAssistant:", " You can make granola by mixing together oats, nuts, and dried fruit. You can also add in some spices, like cinnamon and nutmeg, for flavor."], ...]
This is an example; our code evaluates perplexity on the responses only.
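Concretely, each entry is a [prompt, response] pair, and the list comprehension in the script keeps only `item[1]`. A small sketch of that extraction step, using the sample pair from this thread and the same `clean` fallback as in the script:

```python
def clean(s):
    # Empty responses become a single space so the metric gets non-empty input
    return s if s else ' '

# One [prompt, response] pair in the dataset's format
input_texts = [
    [" \n\nHuman: How do I make granola?\n\nAssistant:",
     " You can make granola by mixing together oats, nuts, and dried fruit."],
]

# Keep only the response (index 1), flatten newlines, strip whitespace
responses = [clean(item[1].replace('\n', ' ').strip()) for item in input_texts]
print(responses[0])
```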
Thank you for your response. My problem has been resolved.
Can you provide a script for calculating perplexity (PPL)?