AntixK / PyTorch-VAE

A Collection of Variational Autoencoders (VAE) in PyTorch.
Apache License 2.0
6.46k stars, 1.05k forks

Permutation Function for FactorVAE #28

Open. DexterJZ opened this issue 3 years ago.

DexterJZ commented 3 years ago

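According to the original paper by Kim et al., the permutation function permutes across the batch, separately for each latent dimension. Here, with B, D = z.size(), the method permute_latent(self, z: Tensor) should therefore permute z along the batch dimension B, i.e., out[i, j] = z[new_indices_j[i], j], where a fresh new_indices_j = torch.randperm(B) is drawn independently for each dimension j. A single permutation shared by all dimensions would only reorder whole rows and leave the latent codes unchanged as a set.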
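A minimal sketch of the batch-wise permutation described above follows, assuming z is a [B x D] tensor. The function name permute_dims is used here for illustration only and is not the repository's code. It draws one torch.randperm(B) per latent dimension and uses torch.gather along dim 0 so that out[i, j] = z[perm[i, j], j].

```python
import torch
from torch import Tensor


def permute_dims(z: Tensor) -> Tensor:
    """Shuffle each latent dimension independently across the batch.

    z: [B x D] latent codes. Returns a new [B x D] tensor in which
    out[i, j] = z[perm_j[i], j], with a fresh random permutation perm_j
    of the B batch indices for every dimension j.
    """
    B, D = z.size()
    # Column j of `perm` is an independent torch.randperm(B) for dimension j.
    perm = torch.stack(
        [torch.randperm(B, device=z.device) for _ in range(D)], dim=1
    )
    # gather along dim 0: out[i, j] = z[perm[i, j], j]
    return torch.gather(z, 0, perm)


z = torch.randn(8, 4)
z_perm = permute_dims(z)
# Each column keeps the same set of values; only the pairing across columns changes.
for j in range(z.size(1)):
    assert torch.equal(z_perm[:, j].sort().values, z[:, j].sort().values)
```

Returning a new tensor rather than assigning into z avoids reading entries that an in-place update has already overwritten.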

mistycheney commented 2 years ago

I also think this is an error. The permutation should be applied to the batch dimension, rather than the factor dimension.
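To illustrate why the axis matters, here is a small sketch; the helpers shuffle and same_values are hypothetical and not taken from the repository. The two choices differ only in the dim passed to argsort and torch.gather. Shuffling along the batch (dim=0) keeps every column's values, so each z_j keeps its empirical marginal while the dependence between dimensions is broken, which is what the FactorVAE discriminator needs as samples from the product of marginals. Shuffling along the latent dimension (dim=1) only keeps each row's values and moves them into other dimensions. Drawing permutations via argsort of uniform noise is a vectorized alternative to looping over torch.randperm.

```python
import torch
from torch import Tensor


def shuffle(z: Tensor, dim: int) -> Tensor:
    """Independently permute z along `dim` (one random permutation per slice).

    dim=0: each latent dimension is shuffled across the batch.
    dim=1: each sample's D entries are shuffled across latent dimensions.
    """
    # argsort of i.i.d. uniform noise yields a random permutation per slice
    idx = torch.rand(z.shape, device=z.device).argsort(dim=dim)
    return torch.gather(z, dim, idx)


def same_values(a: Tensor, b: Tensor) -> bool:
    """True if a and b contain the same values, ignoring order."""
    return torch.equal(a.sort().values, b.sort().values)


# Column j holds 100*j + i, so each latent dimension has its own value range.
B, D = 5, 3
z = torch.arange(B).unsqueeze(1) + 100 * torch.arange(D).unsqueeze(0)

by_batch = shuffle(z, dim=0)
by_latent = shuffle(z, dim=1)

# Batch-wise shuffling keeps every column's values (the marginal of each z_j).
print(all(same_values(by_batch[:, j], z[:, j]) for j in range(D)))  # True
# Latent-wise shuffling only keeps each row's values, moving them across dimensions.
print(all(same_values(by_latent[i], z[i]) for i in range(B)))  # True
```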