I've looked into the code for the implementation and saw this line. Here, the point is to build unpaired samples (x_i, zm_i) that intentionally differ from the true pairs (x_i, z_i). However, `torch.randperm()` does not guarantee a derangement (a permutation with no fixed points, i.e. m_i ≠ i for every i); see the following code snippet.
As you can see in the 3rd output, position 0 maps to 0, so the derangement condition is not satisfied, and in the 4th output both positions 0 and 2 are fixed points. In my own implementation I used a naive approach: I set the permutation to j = i + 1 and replaced the last index with 0. This simple approach guarantees a derangement, but it makes the pairing depend on the order of the batch, which may introduce a bias if the batch is not well shuffled. Please correct me if I'm mistaken or have misunderstood the algorithm/implementation.
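For reference, a uniformly random permutation of n items has no fixed point with probability close to 1/e ≈ 0.37 once n is more than a handful, so `torch.randperm()` alone would leave at least one true pair (x_i, z_i) in roughly two out of three batches. Below is a minimal sketch, assuming PyTorch; the helper names `has_fixed_point`, `shift_derangement` and `random_derangement` are hypothetical and not from the repository. `shift_derangement` is the naive j = i + 1 scheme described above. `random_derangement` swaps in a different technique, rejection sampling (redraw `torch.randperm()` until there is no fixed point), which gives a uniformly random derangement so the pairing no longer depends on batch order.

```python
from typing import Optional

import torch


def has_fixed_point(perm: torch.Tensor) -> bool:
    """True if perm[i] == i for some i, i.e. perm is NOT a derangement."""
    idx = torch.arange(perm.numel(), device=perm.device)
    return bool((perm == idx).any())


def shift_derangement(n: int) -> torch.Tensor:
    """Naive derangement j = i + 1, with the last index wrapping to 0."""
    # roll by -1 turns [0, 1, ..., n-1] into [1, 2, ..., n-1, 0]
    return torch.roll(torch.arange(n), shifts=-1)


def random_derangement(
    n: int, generator: Optional[torch.Generator] = None
) -> torch.Tensor:
    """Uniformly random derangement via rejection sampling on torch.randperm.

    Each accepted permutation is equally likely among all derangements.
    Since about 1/e of permutations qualify, the expected number of draws
    is roughly e (~2.7), and at most 3 for small n >= 2.
    """
    if n == 1:
        raise ValueError("no derangement exists for a single element")
    while True:
        perm = torch.randperm(n, generator=generator)
        if not has_fixed_point(perm):
            return perm
```

With `z` holding the batch of z_i, `z[random_derangement(z.size(0))]` would give the mismatched zm_i, where row i is guaranteed not to be z_i. The extra `torch.randperm()` draws are cheap compared with a forward pass; note that a batch of size 1 has no derangement at all.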