MilaNLProc / contextualized-topic-models

A Python package to run contextualized topic modeling. CTMs combine contextualized embeddings (e.g., BERT) with topic models to get coherent topics. Published at EACL and ACL 2021 (Bianchi et al.).

CTM Memory and Speed Improvements #124

Closed. Supermaxman closed this 1 year ago.

Supermaxman commented 1 year ago

I regularly ran out of memory on large datasets during the CTM fit call. Upon further inspection, I found that the automatic generation of training_doc_topic_distributions via get_doc_topic_distribution after the fit call had several opportunities for improved memory usage. I made the following changes:

  1. Made generating the training_doc_topic_distributions optional at the end of the fit call, defaulting to True
  2. Removed the n_samples loop from the get_doc_topic_distribution operation
  3. Reused posterior_mu and posterior_log_sigma across samples, as they remain the same for every draw
  4. Vectorized the sampling process on the GPU (see the sketch after this list)
  5. Added a get_doc_topic_distribution_iterator operation, which saves significant memory on large datasets by computing topics one batch at a time and yielding them as an iterator (a sketch appears at the end of this comment)
  6. Formatted ctm.py and decoding_network.py with black, for easier readability
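
For concreteness, here is a minimal sketch of what changes 2-4 amount to. The function name and shapes are illustrative, not the package's actual internals: the encoder's posterior parameters are computed once per batch and broadcast across the sample dimension, so all reparameterized draws happen in one vectorized tensor operation instead of a Python loop.

```python
import torch

def sample_doc_topics(posterior_mu, posterior_log_sigma, n_samples):
    """Average n_samples reparameterized draws into one doc-topic distribution.

    posterior_mu, posterior_log_sigma: (batch_size, n_topics) tensors,
    computed once by the encoder and reused for every draw.
    """
    # Broadcast the posterior parameters across the sample dimension
    # instead of re-running the encoder n_samples times.
    mu = posterior_mu.unsqueeze(0)                 # (1, batch, topics)
    log_sigma = posterior_log_sigma.unsqueeze(0)   # (1, batch, topics)

    # One vectorized reparameterization: all draws in a single kernel launch.
    eps = torch.randn(
        n_samples, *posterior_mu.shape, device=posterior_mu.device
    )
    # Assumes log_sigma holds the log standard deviation; under a
    # log-variance parameterization the scale is torch.exp(0.5 * log_sigma).
    z = mu + eps * torch.exp(log_sigma)

    # Softmax over topics, then average the samples -> (batch, topics).
    return torch.softmax(z, dim=-1).mean(dim=0)
```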

I do not believe any of these changes are breaking; they should simply make topic discovery run faster and use less memory.

These efficiency gains enabled me to run the CTM on a massive collection of tweets, so I thought I would create a pull request and offer these improvements back to the original repo, as it still seems active.
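
And a rough sketch of the iterator idea from change 5. Here model.encode_batch is a hypothetical helper standing in for however the posterior parameters are obtained per batch (the actual method added in this PR is get_doc_topic_distribution_iterator on the CTM object); the point is that only one batch of distributions is materialized at a time, so peak memory no longer scales with corpus size.

```python
import torch

def doc_topic_distribution_iterator(model, dataset, batch_size=64, n_samples=20):
    """Yield per-batch doc-topic distributions instead of one giant array.

    model.encode_batch is a hypothetical helper returning the posterior
    parameters for a batch; sample_doc_topics is the sketch above.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    with torch.no_grad():
        for batch in loader:
            posterior_mu, posterior_log_sigma = model.encode_batch(batch)
            yield sample_doc_topics(
                posterior_mu, posterior_log_sigma, n_samples
            ).cpu().numpy()

# Usage: the caller decides whether to stream, accumulate, or write to disk,
# so a corpus of millions of documents never has to fit in memory at once.
# for batch_dist in doc_topic_distribution_iterator(ctm, train_dataset):
#     process(batch_dist)
```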

vinid commented 1 year ago

Thank you so much for this! I'll go over the changes, hopefully tomorrow, and merge them!

vinid commented 1 year ago

Renamed log_sigma to log_var. I know log_sigma is the more coherent name, but the variable is referred to as log_var in the rest of the code; I'd keep it like this for consistency.

vinid commented 1 year ago

merged!