bab2min / tomotopy

Python package of Tomoto, the Topic Modeling Tool
https://bab2min.github.io/tomotopy
MIT License

holdout perplexity #156

Open PearlOnyx08 opened 2 years ago

PearlOnyx08 commented 2 years ago

It appears that the perplexity reported by mdl.perplexity is computed on the training set. Would it be possible to add a function to calculate perplexity on a pre-defined holdout set?

bab2min commented 2 years ago

Hi @PearlOnyx08, thank you for the good suggestion. You can use mdl.infer() to calculate holdout perplexity, but it is somewhat cumbersome. In my opinion, merely adding a holdout-perplexity method would not reduce the inconvenience much. It seems more useful to add a training wrapper that runs the entire training process systematically, rather than just a method that computes holdout perplexity, along the lines of:

train_set = Corpus(...)
eval_set = Corpus(...)
model = tp.LDAModel(...)
trainer = tp.trainer.Trainer(model, train_set, eval_set, max_iteration=1000, logging_interval=100)
# setting observers for hold-in and hold-out perplexity, early-stopping condition, etc.
# you can log hold-in and hold-out perplexity, do plotting or whatever else using observers
# ...
trainer.start() # do training

For the detailed implementation, Hugging Face's Trainer class would be a good model to follow.
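
For illustration, here is a minimal sketch of what such a wrapper could look like with tomotopy as it is today. The Trainer class, its name, and its arguments are hypothetical (tp.trainer does not exist); the body uses only existing APIs: mdl.train(), mdl.perplexity, and mdl.infer() with together=True (see the discussion of together=True below).

import numpy as np
import tomotopy as tp

class Trainer:
    """Hypothetical training wrapper; not part of tomotopy."""
    def __init__(self, model, eval_set, max_iteration=1000, logging_interval=100):
        # the model is assumed to have been constructed with corpus=train_set
        self.model = model
        self.eval_set = eval_set
        self.max_iteration = max_iteration
        self.logging_interval = logging_interval

    def holdout_perplexity(self):
        # infer all held-out documents together so the corpus-level
        # log likelihood is not double-counted across documents
        res, ll = self.model.infer(self.eval_set, together=True)
        n_words = np.sum([len(doc.words) for doc in res])
        return np.exp(-np.sum(ll) / n_words)

    def start(self):
        # train in chunks, logging hold-in and hold-out perplexity;
        # observers/early stopping could hook in here
        for it in range(0, self.max_iteration, self.logging_interval):
            self.model.train(self.logging_interval)
            print(f'iteration {it + self.logging_interval}: '
                  f'hold-in ppl={self.model.perplexity:.2f}, '
                  f'hold-out ppl={self.holdout_perplexity():.2f}')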

lucasgautheron commented 2 years ago

Hi @bab2min, regarding evaluating perplexity from mdl.infer(), how would you do that? My attempt to reproduce the value of mdl.perplexity for the training set using the log likelihood returned by mdl.infer() failed:

import numpy as np
import tomotopy as tp

mdl = tp.CTModel(tw=tp.TermWeight.ONE, corpus=training_corpus, k=k)
mdl.train(n)

# infer the training corpus back and recompute perplexity by hand
res, total_ll = mdl.infer(training_corpus)
words = np.array([len(doc.words) for doc in res])

perplexity = np.exp(-np.sum(total_ll) / np.sum(words))
print(perplexity, mdl.perplexity)  # values differ!

I guess this means that the returned log likelihood is not what I thought it was. What is it exactly? What should I do to make the calculation correct?

Thank you for your help!

bab2min commented 2 years ago

Hi @lucasgautheron, a detailed discussion related to your question is in #147. Currently there are three options for calculating perplexity in tomotopy (mdl.perplexity, mdl.infer(), and doc.get_ll()), but each uses a subtly different formula. I think this is a design mistake in tomotopy, so I am working on unifying them systematically (#148, #149).

The problem in your code is that mdl.infer() actually returns the sum of the doc-topic log likelihood and the topic-term log likelihood for each document. So if you have n documents, np.sum(total_ll) includes n duplicate copies of the topic-term log likelihood. To avoid this, pass together=True to mdl.infer(), which forces all documents to be inferred at the same time (not independently of each other) and returns the sum of the log likelihood over all documents without duplication. However, even then the value may differ slightly from mdl.perplexity, because mdl.infer() assigns a topic to each word through a stochastic process, just like mdl.train().
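
Based on that, a corrected version of the snippet above might look like the following sketch. Here eval_corpus stands in for any held-out tomotopy.utils.Corpus (a hypothetical name), and the result can still differ slightly from mdl.perplexity for the stochastic reason just described.

import numpy as np

# together=True: infer all held-out documents jointly, so the
# topic-term log likelihood is counted only once in the returned ll
res, ll = mdl.infer(eval_corpus, together=True)
n_words = np.sum([len(doc.words) for doc in res])
holdout_perplexity = np.exp(-np.sum(ll) / n_words)
print(holdout_perplexity)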

lucasgautheron commented 2 years ago

Hi,

Thank you very much for your answer, this is super clear now!