@xuChenSJTU Hi, I used MIB on my own dataset, and the loss also didn't drop. I'm wondering whether this indicates a problem. Thanks a lot!
Hi @way151,
In the current implementation, the loss is expressed as the sum of two components:

1. -I(z1;z2)
2. \beta * SKL(p(z1|v1)||p(z2|v2))

Since \beta increases over time due to the annealing schedule, it is possible to observe an increase in the total loss. I would recommend logging the two components separately to verify that SKL(p(z1|v1)||p(z2|v2)) decreases as \beta gets bigger.
I hope this helps to solve your problem, but feel free to ask more questions.
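For reference, here is a minimal PyTorch-style sketch of what logging the two components separately could look like. The names (`training_step`, `mi_estimator`, `encoder_v1`, `encoder_v2`) are placeholders and not this repository's actual API:

```python
import torch

# Hypothetical training step that logs -I(z1;z2) and SKL separately,
# instead of only the combined MIB loss. All names are placeholders.
def training_step(v1, v2, encoder_v1, encoder_v2, mi_estimator, beta):
    # Assume each encoder returns a torch.distributions.Independent(Normal)
    # posterior over the representation, so the KL terms are per-sample.
    p_z1_given_v1 = encoder_v1(v1)
    p_z2_given_v2 = encoder_v2(v2)

    z1 = p_z1_given_v1.rsample()
    z2 = p_z2_given_v2.rsample()

    # Lower bound on I(z1; z2) produced by the critic / MI estimator
    mi_estimation = mi_estimator(z1, z2)

    # Symmetrized KL divergence between the two posteriors
    skl = 0.5 * (
        torch.distributions.kl_divergence(p_z1_given_v1, p_z2_given_v2)
        + torch.distributions.kl_divergence(p_z2_given_v2, p_z1_given_v1)
    ).mean()

    loss = -mi_estimation + beta * skl

    # Log the two components separately, not just the total
    print(f"-I(z1;z2): {-mi_estimation.item():.4f}  "
          f"SKL: {skl.item():.4f}  beta: {beta:.2e}  total: {loss.item():.4f}")
    return loss
```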
Thanks for your reply! I will log the two losses separately. In fact, when I found that L_MIB did not drop, beta was fixed at 1e-4. In addition, the loss only fluctuates very slightly, between 1.3863 and 1.3867, but I checked the gradients of the encoders and the critic module, and all of them have normal values (about 1e-5 to 1e-7) during the backward pass. Could you give me some suggestions about this? Thanks!
I just logged the two parts of L_MIB and found that in the first epoch the second term (SKL) dropped quickly and then converged to 0, but the first term stayed at around 1.386. Additionally, when beta=0, the encoders and the critic module still have normal gradients.
Update: another finding is that when I forward-compute and log the SKL loss but do not backpropagate it (so it produces no gradients), the SKL loss still drops and converges to 0, only more slowly than when beta > 0. I wonder whether this is a normal situation or my implementation went wrong. Thanks a lot.
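For comparison, a minimal sketch of the "compute and log SKL without backpropagating it" variant described above, reusing the hypothetical `mi_estimation` and `skl` terms from the earlier sketch:

```python
def mib_loss_with_monitored_skl(mi_estimation, skl):
    # Optimize only -I(z1;z2); keep SKL as a logged metric that produces
    # no gradients (beta is effectively 0 for the optimization).
    loss = -mi_estimation
    skl_value = skl.detach().item()  # detach so no gradient flows through SKL
    print(f"SKL (monitored only): {skl_value:.4f}")
    return loss
```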
It seems that your model is struggling to detect what the two views have in common. I would recommend first removing the regularization term (\beta = 0) and checking that the mutual information estimator is able to extract the common features (I(z1;z2) >> 0). Then:
I hope this helps.
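A minimal sketch of that sanity check, reusing the hypothetical `training_step` from the earlier sketch (`loader` and `optimizer` are likewise placeholders):

```python
# Sanity check: disable the regularizer and verify that the MI estimator can
# actually find shared information between the two views.
beta = 0.0
for step, (v1, v2) in enumerate(loader):  # `loader` yields paired views
    loss = training_step(v1, v2, encoder_v1, encoder_v2, mi_estimator, beta)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # With beta = 0, the logged I(z1;z2) estimate should grow well above 0;
    # if it stays near 0, the estimator (or the views themselves) is the
    # bottleneck rather than the regularization term.
```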
Yes. I logged and checked the mi_estimation term in the code, and it was almost 0 (less than 1e-3) and sometimes negative. My dataset is multimodal, so maybe I over-estimated the common information between the modalities. One last question: could you recommend other, more powerful estimator networks?
If your estimated mutual information is small, it means that either the two views share very little information or that this information is really hard to access.
From my personal experience, estimators that express mutual information as a dot product of features ("separable" critics, like InfoNCE) are faster but less powerful than those that use a single network taking both representations as input (as in InfoMAX and the current implementation in this repository). Increasing the depth of the estimator could help, at the cost of a longer training procedure. For more information on which estimator to use, I recommend: https://arxiv.org/abs/1905.06922
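To make the distinction concrete, a small PyTorch sketch of the two critic families; both architectures are illustrative and not the ones used in this repository:

```python
import torch
import torch.nn as nn

class SeparableCritic(nn.Module):
    """'Separable' critic (InfoNCE-style): score is a dot product of projected features."""
    def __init__(self, z_dim, hidden=256):
        super().__init__()
        self.f1 = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
        self.f2 = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU(), nn.Linear(hidden, hidden))

    def forward(self, z1, z2):
        # Returns a [batch, batch] matrix of scores for all pairings at once.
        return self.f1(z1) @ self.f2(z2).t()

class JointCritic(nn.Module):
    """Joint critic: a single network scores each concatenated (z1, z2) pair."""
    def __init__(self, z_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * z_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, z1, z2):
        # Scores one (z1, z2) pair per row; building the full [batch, batch]
        # score matrix requires evaluating the network on every pairing.
        return self.net(torch.cat([z1, z2], dim=-1)).squeeze(-1)
```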
What could also help is using BatchNorm (or other normalization layers), more layers, and a larger representation size in the encoder architectures. In case your views have different marginal distributions, make sure you use two distinct encoders, or only limited weight sharing.
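As an illustration, a small encoder sketch along those lines (the layer sizes are arbitrary assumptions, not tuned recommendations):

```python
import torch.nn as nn

def make_encoder(in_dim, z_dim=128, hidden=512):
    # A deeper MLP encoder with BatchNorm and a larger representation size.
    return nn.Sequential(
        nn.Linear(in_dim, hidden), nn.BatchNorm1d(hidden), nn.ReLU(),
        nn.Linear(hidden, hidden), nn.BatchNorm1d(hidden), nn.ReLU(),
        nn.Linear(hidden, 2 * z_dim),  # e.g. mean and log-variance of p(z|v)
    )

# Two distinct encoders when the views have different marginal distributions.
encoder_v1 = make_encoder(in_dim=784)
encoder_v2 = make_encoder(in_dim=784)
```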
@mfederici Hi, I followed your suggestions. I tried using a shared encoder and tuning some hyper-parameters, and the loss finally dropped! I had tried InfoNCE before and it didn't work well on my dataset in my setting. What I need is to mine the common information for downstream tasks in a supervised learning setting, while keeping the performance of both modalities from decreasing. Thank you for your patience and help.
I am glad to hear that you could solve the issue! If you are working in a supervised setting, you could also add a supervised component to the loss to help extract predictive information, even though this could potentially come at the cost of an increased discrepancy between the two views.
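A minimal sketch of what such a supervised component could look like; the classifier head and the weight `lambda_sup` are hypothetical additions, not part of this repository:

```python
import torch.nn as nn
import torch.nn.functional as F

z_dim, num_classes = 64, 10                  # placeholder sizes
classifier = nn.Linear(z_dim, num_classes)   # hypothetical head on top of z1
lambda_sup = 1.0                             # weight of the supervised term

def mib_plus_supervised_loss(mi_estimation, skl, beta, z1, labels):
    # MIB objective (as in the earlier sketches) plus a cross-entropy term
    # that encourages the representation to stay predictive of the labels.
    supervised = F.cross_entropy(classifier(z1), labels)
    return -mi_estimation + beta * skl + lambda_sup * supervised
```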
Hi,
Thanks for sharing the code. I just ran your code on MNIST, but it seems the loss does not drop. This is the training log:
```
Train Loss: 16.65132975578308   Train Accuracy: 0.440000   Test Accuracy: 0.358600
Train Loss: 16.20352667570114   Train Accuracy: 0.520000   Test Accuracy: 0.420100
Train Loss: 18.444831013679504  Train Accuracy: 0.480000   Test Accuracy: 0.403400
Train Loss: 24.42823952436447   Train Accuracy: 0.500000   Test Accuracy: 0.411100
Train Loss: 23.091638684272766  Train Accuracy: 0.600000   Test Accuracy: 0.466200
Train Loss: 17.668635308742523  Train Accuracy: 0.640000   Test Accuracy: 0.514900
Train Loss: 26.7216295003891    Train Accuracy: 0.650000   Test Accuracy: 0.520900
Train Loss: 24.259143024683     Train Accuracy: 0.640000   Test Accuracy: 0.554900
Train Loss: 24.502391815185547  Train Accuracy: 0.650000   Test Accuracy: 0.539600
Train Loss: 28.630192399024963  Train Accuracy: 0.690000   Test Accuracy: 0.579100
Train Loss: 26.80162337422371   Train Accuracy: 0.710000   Test Accuracy: 0.584000
Train Loss: 18.396559059619904  Train Accuracy: 0.740000   Test Accuracy: 0.590100
Train Loss: 21.14952301979065   Train Accuracy: 0.720000   Test Accuracy: 0.597600
Train Loss: 28.22284960746765   Train Accuracy: 0.710000   Test Accuracy: 0.609900
Train Loss: 19.738620832562447  Train Accuracy: 0.740000   Test Accuracy: 0.598700
Train Loss: 26.586383253335953  Train Accuracy: 0.720000   Test Accuracy: 0.610100
Train Loss: 27.010501861572266  Train Accuracy: 0.710000   Test Accuracy: 0.615600
Train Loss: 24.147621154785156  Train Accuracy: 0.730000   Test Accuracy: 0.617300
Train Loss: 16.216432973742485  Train Accuracy: 0.720000   Test Accuracy: 0.626300
Train Loss: 29.65166562795639   Train Accuracy: 0.690000   Test Accuracy: 0.623700
Train Loss: 19.953528456389904  Train Accuracy: 0.740000   Test Accuracy: 0.630500
Train Loss: 24.09559604525566   Train Accuracy: 0.730000   Test Accuracy: 0.614100
Train Loss: 27.629364281892776  Train Accuracy: 0.700000   Test Accuracy: 0.634300
Train Loss: 24.558568745851517  Train Accuracy: 0.740000   Test Accuracy: 0.642000
Train Loss: 23.52346968650818   Train Accuracy: 0.730000   Test Accuracy: 0.649100
Train Loss: 28.106444120407104  Train Accuracy: 0.750000   Test Accuracy: 0.654100
Train Loss: 16.491002649068832  Train Accuracy: 0.740000   Test Accuracy: 0.653300
Train Loss: 24.8224079310894    Train Accuracy: 0.760000   Test Accuracy: 0.656900
Train Loss: 22.03815546631813   Train Accuracy: 0.740000   Test Accuracy: 0.660300
Train Loss: 21.805272683501244  Train Accuracy: 0.790000   Test Accuracy: 0.678100
Train Loss: 18.99772535264492   Train Accuracy: 0.790000   Test Accuracy: 0.689200
Train Loss: 23.89051579684019   Train Accuracy: 0.790000   Test Accuracy: 0.695400
Train Loss: 19.56764091551304   Train Accuracy: 0.810000   Test Accuracy: 0.713800
Train Loss: 18.89282363653183   Train Accuracy: 0.810000   Test Accuracy: 0.712700
Train Loss: 22.087528944015503  Train Accuracy: 0.830000   Test Accuracy: 0.725400
Train Loss: 20.261658288538456  Train Accuracy: 0.830000   Test Accuracy: 0.729000
Train Loss: 27.495856314897537  Train Accuracy: 0.870000   Test Accuracy: 0.740800
Train Loss: 21.786702036857605  Train Accuracy: 0.820000   Test Accuracy: 0.735500
```
The loss even increases compared to earlier epochs. Is anything wrong?
Thanks.