adjidieng / DETM


Loss versus KL #7

Closed · mona-timmermann closed this issue 3 years ago

Emekaborisama commented 3 years ago

I didn't experience this while using D-ETM

sdspieg commented 3 years ago

Mona - sorry to reach out like this here via GitHub, but we've been trying to find people who might be able to help us reproduce Adji's findings. We have some ongoing projects in which we try to use various NLP tools (including DETM) to build more/better knowledge about various aspects of international relations, and we're writing a bunch of new funding proposals along those lines. So if you, or anybody else who catches this, is intrigued/interested, please do let me know...

Emekaborisama commented 3 years ago

I really think you need to debug this: during my training my val PPL wasn't as high as yours, and neither was my KL on eta.

However, I tried it using the Adam optimizer.
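
For context, this is roughly what selecting Adam with a reduced learning rate looks like in PyTorch. It is a minimal sketch, not the repository's training code; `DETMStandIn` is just a placeholder module and the hyperparameter values are illustrative:

```python
import torch
import torch.nn as nn

# Placeholder for the actual DETM model class from the repository.
class DETMStandIn(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(300, 50)

    def forward(self, x):
        return self.layer(x)

model = DETMStandIn()

# Adam with a small learning rate and light weight decay; values are
# illustrative, not the repository's defaults.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1.2e-6)
```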

waahlstrand commented 3 years ago

@mona-timmermann I have experienced the same issues. The consequence of an increasing KL divergence on θ is a drastic drop in topic quality, so it is indeed an important issue.

My experience is that the DETM, which uses an LSTM for the topic dynamics, is very sensitive to parameter choices. Try lowering the learning rate further. The following yields a decreasing, or at least more consistently decreasing, KL divergence on all parameters for me:

python main.py --dataset un --data_path ../data_undebates_largev/split_paragraph_0/ --emb_path ../embeddings/un/skipgram_emb_300d.txt --min_df 10 --num_topics 50 --lr 0.0001 --epochs 400 --mode train

(Assuming you use the original data folders).
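
To judge whether a given learning rate actually tames the KL terms, it can help to track them per epoch. Here is a minimal sketch of such tracking, assuming you collect the per-epoch values the training loop prints; the helper and its argument names are hypothetical and not part of the repository:

```python
from collections import defaultdict

history = defaultdict(list)

def record_epoch(epoch, kl_theta, kl_eta, kl_alpha, val_ppl):
    """Store per-epoch loss components and flag a rising KL_theta."""
    history["kl_theta"].append(kl_theta)
    history["kl_eta"].append(kl_eta)
    history["kl_alpha"].append(kl_alpha)
    history["val_ppl"].append(val_ppl)
    if len(history["kl_theta"]) > 1 and kl_theta > history["kl_theta"][-2]:
        print(f"epoch {epoch}: KL_theta rose from "
              f"{history['kl_theta'][-2]:.3f} to {kl_theta:.3f}")

# Example usage with made-up numbers:
record_epoch(0, kl_theta=12.4, kl_eta=3.1, kl_alpha=0.8, val_ppl=2100.0)
record_epoch(1, kl_theta=13.0, kl_eta=2.9, kl_alpha=0.7, val_ppl=2050.0)
```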

jfcann commented 3 years ago

Hi all, a rising KL_theta isn't, in and of itself, a problem. KL_theta is the KL divergence of the approximate posterior obtained by variational inference from its prior (a standard diagonal normal). The fact that KL_theta rises (and then presumably roughly converges) indicates that the variational inference part of the model has learnt something.
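
For concreteness, here is a minimal sketch of the closed-form KL term described above: the KL divergence between a diagonal Gaussian posterior N(μ, diag(σ²)) and a standard normal prior, with the general diagonal-Gaussian case included for reference. This illustrates the standard formula and is not a copy of the repository's implementation:

```python
import torch

def gaussian_kl(q_mu, q_logsigma, p_mu=None, p_logsigma=None):
    """KL( N(q_mu, diag(exp(q_logsigma)^2)) || prior ).

    With p_mu / p_logsigma omitted, the prior is taken to be a standard
    normal N(0, I), the case described above for KL_theta.
    """
    if p_mu is None and p_logsigma is None:
        # Closed form against N(0, I): -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2)
        return -0.5 * torch.sum(
            1 + 2 * q_logsigma - q_mu.pow(2) - (2 * q_logsigma).exp(), dim=-1
        )
    # General KL between two diagonal Gaussians.
    var_ratio = (2 * (q_logsigma - p_logsigma)).exp()
    scaled_mean_diff = ((q_mu - p_mu) / p_logsigma.exp()).pow(2)
    return 0.5 * torch.sum(
        var_ratio + scaled_mean_diff - 1 - 2 * (q_logsigma - p_logsigma), dim=-1
    )

# Example: KL of a slightly off-standard posterior from N(0, I).
mu = torch.zeros(5) + 0.1
logsigma = torch.zeros(5) - 0.05
print(gaussian_kl(mu, logsigma))
```

A rising value of this term during training just means the posterior is moving away from the prior as it fits the data, which is consistent with the point above.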