According to the original paper by Kim et al., the permutation function permutes across the batch independently for each dimension. In the case here, if B, D = z.size(), then def permute_latent(self, z: Tensor) should permute z along the batch dimension B separately for each latent dimension, i.e., z[i, j] = z[new_indices_j[i], j], where new_indices_j = torch.randperm(B) is drawn independently for each dimension j. Note that using a single torch.randperm(B) shared across all dimensions would only shuffle whole samples and would not break the dependence between latent dimensions.
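A minimal sketch of such a per-dimension permutation (written as a standalone function for illustration; the method in the repo takes self and may differ in details):

```python
import torch
from torch import Tensor


def permute_latent(z: Tensor) -> Tensor:
    """Permute each latent dimension independently across the batch.

    For every dimension j, a fresh random permutation of the batch
    indices is drawn, so correlations between dimensions are broken.
    """
    B, D = z.size()
    permuted = torch.empty_like(z)
    for j in range(D):
        new_indices = torch.randperm(B, device=z.device)
        permuted[:, j] = z[new_indices, j]
    return permuted
```

Each column of the output is a permutation of the corresponding column of the input, but the permutations differ across columns, which is what makes the permuted batch approximate a sample from the product of the marginals.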