gtegner / mine-pytorch

Mutual Information Neural Estimation in Pytorch
MIT License
294 stars, 56 forks

MINE optimizer is flawed #5

Open devrimcavusoglu opened 1 year ago

devrimcavusoglu commented 1 year ago

I've looked into the implementation and noticed this line. Its purpose is to form unpaired samples (x_i, zm_i), where zm_i comes from a shuffled batch and should therefore differ from the paired z_i in (x_i, z_i). However, torch.randperm() does not guarantee a derangement (a permutation in which no element stays at its original position), as the following snippet shows.

```python
import torch

torch.manual_seed(42)
torch.randperm(5)  # Out[1]: tensor([2, 4, 3, 0, 1])
torch.randperm(5)  # Out[2]: tensor([1, 4, 3, 2, 0])
torch.randperm(5)  # Out[3]: tensor([0, 2, 3, 1, 4])
torch.randperm(5)  # Out[4]: tensor([0, 3, 2, 1, 4])
```
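The fixed points in these outputs can be checked mechanically. A minimal sketch in plain Python; the helper name `fixed_points` is hypothetical, and the permutations are copied from the printed outputs above rather than regenerated:

```python
def fixed_points(perm):
    """Return the indices i where perm[i] == i, i.e. where a sample would be paired with itself."""
    return [i for i, j in enumerate(perm) if i == j]


# Permutations copied from the randperm outputs above.
outputs = [
    [2, 4, 3, 0, 1],
    [1, 4, 3, 2, 0],
    [0, 2, 3, 1, 4],
    [0, 3, 2, 1, 4],
]

for k, perm in enumerate(outputs, start=1):
    print(k, fixed_points(perm))
```

This reports no fixed points for the 1st and 2nd outputs, positions 0 and 4 for the 3rd, and positions 0, 2, and 4 for the 4th. For a uniformly random permutation of n items, the probability of at least one fixed point approaches 1 - 1/e (about 0.63), so this is a common event at any realistic batch size, not an edge case.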

As you can see, the 3rd output leaves positions 0 and 4 in place, and the 4th output leaves positions 0, 2, and 4 in place, so neither is a derangement.

In my own implementation I used a naive approach instead: a cyclic shift that pairs sample i with j = i + 1 and wraps the last item around to index 0. This simple scheme always yields a derangement, but the pairing then depends entirely on the order of the given batch, so it may introduce a bias if the batch is not well randomized. Please correct me if I'm mistaken or have misunderstood the algorithm/implementation.
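A minimal sketch of two ways to build the shuffled batch so that no sample is paired with itself, assuming the marginal samples are formed by indexing the z batch with a permutation. The function names are hypothetical and not taken from mine-pytorch. `cyclic_shift_perm` is the j = i + 1 scheme described above; `random_derangement` swaps in a different technique, rejection sampling, which redraws `torch.randperm` until no index maps to itself:

```python
import torch


def cyclic_shift_perm(n: int) -> torch.Tensor:
    """Pair sample i with sample (i + 1) % n.

    Always a derangement for n >= 2, but the pairing is fixed by batch order.
    """
    return (torch.arange(n) + 1) % n


def random_derangement(n: int, generator=None) -> torch.Tensor:
    """Redraw torch.randperm until no index is a fixed point (rejection sampling).

    The accepted permutation is uniform over all derangements of size n;
    on average about e (~2.72) draws are needed for large n.
    """
    if n < 2:
        raise ValueError("a derangement needs at least 2 elements")
    idx = torch.arange(n)
    while True:
        perm = torch.randperm(n, generator=generator)
        if not torch.any(perm == idx):
            return perm


# Hypothetical usage: build "marginal" pairs (x_i, z_marginal_i) from a joint batch.
x = torch.randn(8, 3)
z = torch.randn(8, 2)
z_marginal = z[random_derangement(len(z))]
```

The cyclic shift costs a single index computation with no retries (indexing with it is equivalent to `torch.roll(z, shifts=-1, dims=0)`), but it ties each sample to its neighbour in the batch. Rejection sampling keeps the order independence of `torch.randperm` at the cost of a few extra draws per batch.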