pytorch / torcheval

A library that contains a rich collection of performant PyTorch model metrics, a simple interface to create new metrics, a toolkit to facilitate metric computation in distributed training and tools for PyTorch model evaluations.
https://pytorch.org/torcheval

More precise definition of perplexity when ignore index is not None #154

Closed sh0416 closed 1 year ago

sh0416 commented 1 year ago

When we compute perplexity, we usually aggregate all token-level NLLs and divide by the number of tokens.

This is simple and straightforward to implement when all tokens are included in the computation.

If we allow an ignore index, however, I think the implementation produces a value that is not really the perplexity.

Consider a batch where the input ids contain two sequences of different lengths.

16, 22, 14, -100, -100
88, 74, 69, 87, 77

And the token-level losses (I just masked out the ignored tokens):

0.1, 0.2, 0.1, 0.0, 0.0
0.3, 0.5, 0.1, 0.4, 0.2

Since perplexity is a sequence-level concept, the perplexity of each sequence might be the following:

exp(0.4 / 3) = a
exp(1.5 / 5) = b

And then reduce by the mean:

(a+b)/2

However, to the best of my knowledge, your implementation would give us the following.

exp((0.4 + 1.5) / 8)
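To make the comparison concrete, here is a small standalone sketch in plain Python (not the torcheval API; the losses and token counts are just the numbers from the example above) that computes both quantities:

```python
import math

# Token-level losses for the two sequences, with the ignored positions (-100)
# already masked out to 0.0, as in the example above.
losses = [
    [0.1, 0.2, 0.1, 0.0, 0.0],  # 3 real tokens, 2 ignored
    [0.3, 0.5, 0.1, 0.4, 0.2],  # 5 real tokens
]
num_tokens = [3, 5]

# "Former" definition: per-sequence perplexity, then mean over sequences.
seq_ppls = [math.exp(sum(seq) / n) for seq, n in zip(losses, num_tokens)]
ppl_sequence_mean = sum(seq_ppls) / len(seq_ppls)  # (a + b) / 2

# "Latter" definition: one exp over the corpus-level mean of all non-ignored tokens.
ppl_token_level = math.exp(sum(map(sum, losses)) / sum(num_tokens))  # exp((0.4 + 1.5) / 8)

print(ppl_sequence_mean)  # (exp(0.4/3) + exp(1.5/5)) / 2 ≈ 1.246
print(ppl_token_level)    # exp(1.9 / 8) ≈ 1.268
```

The two definitions give different numbers here precisely because the sequences have different numbers of non-ignored tokens.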

If you consider the latter to be the definition of perplexity when an ignore index is involved, then it doesn't matter.

However, I think the former is a more precise implementation of perplexity.

What do you think?

sh0416 commented 1 year ago

According to the definition on Wikipedia, the latter is right. I don't know why some people compute perplexity at the sentence level, but there must be a reason for it. Closing the issue.