Linear95 / CLUB

Code for ICML2020 paper - CLUB: A Contrastive Log-ratio Upper Bound of Mutual Information

Hi, thanks for the good work. I have a general question: according to your code, the positive term in the PyTorch version subtracts a logvar term, but the TensorFlow version does not. Is there a reason for this difference between the two versions? I also encounter a problem in MI minimization: the MI estimate in the early training epochs is always < 0. Is this reasonable, and are there any tips to solve it? #17

Open xiaomi4356 opened 1 year ago

xiaomi4356 commented 1 year ago
    Hi, thanks for the good work. I have a general question: according to your code, the positive term in the PyTorch version subtracts a logvar term, but the TensorFlow version does not. Is there a reason for this difference between the two versions? I also encounter a problem in MI minimization: the MI estimate in the early training epochs is always < 0. Is this reasonable, and are there any tips to solve it?

Originally posted by @bonehan in https://github.com/Linear95/CLUB/issues/12#issuecomment-1111890242

xiaomi4356 commented 1 year ago

Hello, I also encountered this problem; have you fixed it? In addition, I met a problem in MI minimization: the lld_loss, which is used to train the CLUB network, is always < 0. Have you seen this issue?

Linear95 commented 1 year ago

Hi, for the question about the logvar term: you can regard CLUB without logvar as using a conditional Gaussian with a fixed variance (var = 1) for the variational approximation q(y|x). If the variance is fixed at 1, then logvar = 0 and the term drops out. This trick narrows the variational distribution family for learning p(y|x), which provides less flexibility but more stability. I just added a PyTorch version of CLUB without logvar into mi_estimators.py, named CLUBMean. Feel free to try it in our MI estimation and minimization demo notebooks.
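To make the fixed-variance point concrete, here is a minimal NumPy sketch of the CLUBMean-style bound (not the repository's exact PyTorch code): with var = 1 the logvar term cancels from both the positive and negative terms, leaving only squared errors against the predicted means. `mu` stands for the means produced by the variational net on the batch.

```python
import numpy as np

def club_mean_estimate(mu, y):
    """CLUB upper bound with q(y|x) = N(mu(x), I), i.e. logvar fixed to 0.

    mu: (N, D) predicted means of q(y|x) for each sample x_i
    y:  (N, D) paired samples y_i
    """
    # Positive term: log q(y_i | x_i) up to an additive constant.
    positive = -0.5 * ((y - mu) ** 2).sum(axis=1)
    # Negative term: for each x_i, average log q(y_j | x_i) over all j.
    diff = y[None, :, :] - mu[:, None, :]            # (N, N, D) pairwise
    negative = -0.5 * (diff ** 2).sum(axis=2).mean(axis=1)
    # The shared constants and logvar terms cancel in the difference.
    return (positive - negative).mean()
```

Since the positive term is always at least the average of the pairwise terms when `mu` fits the data, the estimate is non-negative for a well-trained estimator; a persistently negative value signals that q(y|x) is fitted poorly.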

For the problem of negative MI values, it is usually because the variational approximation is not learned well enough. You might need to increase the learning rate or the number of updating steps for the MI estimator to learn q(y|x). You might also need to change the architecture (number of layers or hidden size) of q(y|x) so that the variational net fits your own data samples.
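The "more updating steps" advice amounts to giving the variational net its own inner loop before each MI-minimization step. A toy NumPy sketch, assuming a hypothetical linear variational net q(y|x) = N(Wx, I) trained by gradient descent on its negative log-likelihood (the names and step counts are illustrative, not from the repo):

```python
import numpy as np

def fit_variational_net(x, y, W, lr=0.1, inner_steps=5):
    """Take several gradient steps on q(y|x) = N(x @ W, I) before the
    next MI-minimization step, so the approximation keeps up with the
    encoder. x: (N, d_x), y: (N, d_y), W: (d_x, d_y)."""
    for _ in range(inner_steps):
        mu = x @ W
        # Gradient of the mean negative log-likelihood wrt W.
        grad = -(x.T @ (y - mu)) / len(x)
        W = W - lr * grad
    return W
```

In the actual training loop, `inner_steps > 1` (or a larger estimator learning rate) keeps the log-likelihood loss from lagging behind, which is typically what removes the negative MI estimates in early epochs.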