Linear95 / CLUB

Code for the ICML 2020 paper - CLUB: A Contrastive Log-ratio Upper Bound of Mutual Information

About the logvar prediction #12

daxintan-cuhk opened this issue 3 years ago (status: Open)

daxintan-cuhk commented 3 years ago

Thank you for your excellent code! I have encountered a problem when using the mutual information constraint in a speech processing task. During training, I found that the logvar prediction network, whose last layer is 'Tanh', always outputs '-1', no matter what the input is. The overall mutual information estimation network also seems to be ineffective, as the log-likelihoods of the positive samples in the training batch are all very small values, something like -1,000,000. Has any other user met this problem before? Or do you have any advice? Thank you a lot!

Yours, Daxin

Linear95 commented 3 years ago

Hi Daxin,

For the output of logvar, you can try any other activation function you want. Here I guess the main problem in your case is that your q(y|x) network is not well learned before doing the MI minimization. A probable solution might be to enlarge the learning rate for q(y|x)'s parameters, or to increase the number of CLUB training steps within each MI minimization iteration.

Thanks and good luck!
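For concreteness, here is a minimal sketch of that schedule. It assumes the `CLUB` class from this repo's `mi_estimators.py`, with its `learning_loss` method and a `forward` that returns the MI estimate; the hyperparameters, the random stand-in batches, and the omitted task model are illustrative assumptions, not the repo's actual training script.

```python
import torch
from mi_estimators import CLUB  # estimator class from this repo (assumed import path)

x_dim, y_dim, hidden_size, batch_size = 16, 16, 64, 128
mi_estimator = CLUB(x_dim, y_dim, hidden_size)

# Larger learning rate for q(y|x)'s parameters, as suggested above (value is illustrative).
mi_optimizer = torch.optim.Adam(mi_estimator.parameters(), lr=1e-3)

inner_steps = 5  # number of CLUB updates per MI-minimization iteration

for iteration in range(1000):
    # In a real task these batches would come from your encoder;
    # random tensors just keep the sketch self-contained.
    x_samples = torch.randn(batch_size, x_dim)
    y_samples = torch.randn(batch_size, y_dim)

    # 1) Fit q(y|x) for several steps so it tracks the current representations.
    mi_estimator.train()
    for _ in range(inner_steps):
        mi_optimizer.zero_grad()
        mi_estimator.learning_loss(x_samples, y_samples).backward()
        mi_optimizer.step()

    # 2) Use the (now better-fitted) estimator for the MI-minimization step.
    mi_estimator.eval()
    mi_upper_bound = mi_estimator(x_samples, y_samples)
    # ... add a weighted mi_upper_bound to your task loss and update the main model here ...
```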

guanyadong commented 3 years ago

I also encountered this problem, have you fixed it?

gaoxinrui commented 2 years ago

Hi, thanks for the good work. I have a general question: according to your code, a Gaussian distribution is actually estimated using a NN. However, for a Gaussian, the first two moments can be calculated directly from the samples. So what is the advantage of using a NN to estimate it? Further, do you have an idea of how to estimate a general distribution using a NN?

Thanks a lot.

Linear95 commented 2 years ago

In this work, we do not try to directly estimate a general Gaussian distribution from samples. Instead, we aim to estimate the conditional distribution p(Y|X) with a variational neural network. In our setups, given each value X=X0, the conditional distribution p(Y|X=X0) is a Gaussian distribution. What we want the neural network to learn is not one particular Gaussian distribution p(Y|X=X0) (which, as you said, you can estimate with moments), but the relation between X and Y, so that for each value X=X0 (even if X0 is unseen in the samples) we can approximate the conditional distribution p(Y|X=X0). We do not impose any constraint on the marginal distribution of Y, i.e., p(Y).

To estimate a general Gaussian distribution from samples, there are plenty of prior works. Moment estimation is one such method (and, for the Gaussian, it coincides with maximum likelihood estimation (MLE)). However, estimating the covariance matrix is quite involved, and evaluating the density function also becomes difficult when the sample dimension is high.
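As a side note, the moment/MLE estimate mentioned above is just the sample mean and the (biased) sample covariance; a small NumPy sketch on toy data (the dimensions and variable names are illustrative only):

```python
import numpy as np
from scipy.stats import multivariate_normal

# Toy data: n samples of a correlated d-dimensional Gaussian.
n, d = 1000, 5
samples = np.random.randn(n, d) @ np.random.randn(d, d)

# The MLE of the mean and covariance are the sample moments.
mu_hat = samples.mean(axis=0)
sigma_hat = np.cov(samples, rowvar=False, bias=True)  # bias=True divides by n (the MLE)

# Evaluating the density involves the inverse of sigma_hat,
# which becomes expensive and ill-conditioned in high dimensions.
density = multivariate_normal(mean=mu_hat, cov=sigma_hat).pdf(samples[0])
```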

Estimating a general distribution is also an interesting topic. To my knowledge, Generative Adversarial Networks (GANs) can directly draw realistic samples from the distribution of the given sample data. If you want the density function of a general distribution estimated from samples, you can check methods such as kernel density estimation (KDE).

gaoxinrui commented 2 years ago

Thanks. I understand that you were calculating the conditional distribution, which is assumed to be Gaussian. For a Gaussian conditional distribution p(Y|X), the optimal prediction of Y, i.e., the conditional expectation E(Y|X), is essentially a linear function of X. This linear function is related to Pearson's correlation ρ, as in your recently uploaded Mutual Information Minimization Demo. Pearson's correlation can easily be obtained; there is no need to use a NN to estimate it. If the relationship between X and Y is strongly nonlinear, which will lead to a non-Gaussian conditional distribution, I wonder if the method still works well.
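For reference, the textbook identity behind this observation in the jointly Gaussian case (standard background, not something taken from the repo):

```latex
\mathbb{E}[Y \mid X = x] = \mu_Y + \rho\,\frac{\sigma_Y}{\sigma_X}\,(x - \mu_X),
\qquad
\operatorname{Var}(Y \mid X = x) = (1 - \rho^2)\,\sigma_Y^2
```

so the conditional mean is linear in x and fully determined by the first two moments and Pearson's ρ.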

Linear95 commented 2 years ago

Good question. That is also the reason why we introduce a variational NN to approximate p(Y|X) as q_\theta(Y|X) in our paper. To handle the non-linearity between X and Y, we parameterize q_\theta(Y|X) as N(mu(x), sigma^2(x)), so that we can non-linearly predict the mean mu(x) and variance sigma^2(x) of the conditional Gaussian as the NN's outputs, with X as the input.
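A minimal sketch of this parameterization (the layer sizes, the ReLU hidden layers, and the Tanh on the logvar head are illustrative assumptions in the spirit of the repo's estimators, not an exact copy of the code):

```python
import torch
import torch.nn as nn

class ConditionalGaussian(nn.Module):
    """Variational approximation q_theta(y|x) = N(mu(x), diag(exp(logvar(x))))."""

    def __init__(self, x_dim, y_dim, hidden_size):
        super().__init__()
        # Non-linear predictor of the conditional mean mu(x).
        self.p_mu = nn.Sequential(
            nn.Linear(x_dim, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, y_dim))
        # Non-linear predictor of the conditional log-variance logvar(x).
        self.p_logvar = nn.Sequential(
            nn.Linear(x_dim, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, y_dim),
            nn.Tanh())  # bounds logvar; see the discussion at the top of this thread

    def forward(self, x):
        return self.p_mu(x), self.p_logvar(x)

    def log_likelihood(self, x, y):
        # log N(y; mu(x), exp(logvar(x))), summed over dimensions,
        # averaged over the batch, additive constants dropped.
        mu, logvar = self(x)
        return (-(y - mu) ** 2 / logvar.exp() - logvar).sum(dim=1).mean() / 2.0
```

Maximizing `log_likelihood` on positive pairs is what fits q_\theta(y|x) before each MI-minimization step, as discussed earlier in this thread.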

bonehan commented 2 years ago

Hi, thanks for the good work. I have a general question: according to your code, the positive term in the PyTorch version subtracts a logvar term, but in the TensorFlow version it doesn't. Are there any tips behind these two versions? I also encounter a problem in MI minimization: the MI estimate in the earlier training epochs is always < 0. Is that reasonable, and are there any tips to solve it?

Linear95 commented 2 years ago

Calculating CLUB without logvar is equivalent to setting the variance of the conditional Gaussian p(y|x) to 1, which is still within our theoretical framework. By fixing the variance of p(y|x), we obtain a more stable but less flexible MI estimation. For negative MI estimation, you can check my suggestion here.
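To make that comparison concrete, a simplified sketch of the two positive-term variants (diagonal Gaussians, additive constants dropped; the function names are illustrative and this is not the exact repo code):

```python
import torch

def positive_term_learned_var(mu, logvar, y):
    # PyTorch-style variant: log N(y; mu(x), exp(logvar(x))) up to constants,
    # so a logvar term is subtracted alongside the variance-scaled squared error.
    return (-(y - mu) ** 2 / logvar.exp() - logvar).sum(dim=1).mean() / 2.0

def positive_term_unit_var(mu, y):
    # Fixed-variance variant (variance = 1): the logvar term vanishes and only
    # the squared error remains (more stable, but less flexible).
    return (-(y - mu) ** 2).sum(dim=1).mean() / 2.0
```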

LindgeW commented 7 months ago

When var is set to 1, will it lead to any performance degradation?

LindgeW commented 7 months ago

Is the Tanh activation function on the logvar output required? Can you remove it, or just replace it with something else?