Saswatm123 / MMD-VAE

PyTorch implementation of the Maximum Mean Discrepancy Variational Autoencoder (MMD-VAE), a member of the InfoVAE family that maximizes the mutual information between the isotropic Gaussian prior (used as the latent space) and the data distribution.

posterior collapses #1

panmianzhi opened this issue 8 months ago

panmianzhi commented 8 months ago

I am training Stable Diffusion on my dataset (weather images of size (200, 320)). I found that the images reconstructed by the vanilla VAE, which Stable Diffusion uses by default, are very blurry, so I tried MMD-VAE to improve the reconstruction quality. During training, however, the posterior of MMD-VAE always collapses, i.e. the decoder outputs noise if I feed it a standard Gaussian sample. I am not sure whether this is because the hidden dimension of the posterior is too big; specifically, the posterior has shape (batch_size, 4, 25, 40). When I compute the MMD, I flatten the posteriors and use your code as follows:

import torch

def gaussian_kernel(x, y):
    # x, y: (B, C, H, W); compares every pair of flattened samples
    x_size = x.size(0)
    y_size = y.size(0)
    dim = x.size(1) * x.size(2) * x.size(3)
    x = x.view(x_size, -1)
    y = y.view(y_size, -1)
    tiled_x = x.unsqueeze(1).expand(x_size, y_size, dim)
    tiled_y = y.unsqueeze(0).expand(x_size, y_size, dim)
    kernel_input = (tiled_x - tiled_y).pow(2).mean(2) / float(dim)
    return torch.exp(-kernel_input)

def MMD(x, y):
    '''
    :param x: samples from one distribution, shape (B, C, H, W)
    :param y: samples from the other distribution, shape (B, C, H, W)
    :return: scalar MMD estimate between the two sample sets
    '''
    x_kernel = gaussian_kernel(x, x)
    y_kernel = gaussian_kernel(y, y)
    xy_kernel = gaussian_kernel(x, y)
    mmd = x_kernel.mean() + y_kernel.mean() - 2*xy_kernel.mean()
    return mmd
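
For reference, here is roughly how I combine these functions with the reconstruction term in my training step. The encoder/decoder names and the weight on the MMD term are placeholders rather than my exact code:

import torch
import torch.nn.functional as F

def vae_loss(encoder, decoder, images, mmd_weight=1.0):
    # posterior samples, shape (batch_size, 4, 25, 40)
    z = encoder(images)
    recon = decoder(z)
    # reconstruction term
    recon_loss = F.mse_loss(recon, images)
    # samples from the standard Gaussian prior, same shape as z
    prior_samples = torch.randn_like(z)
    # MMD between prior samples and posterior samples
    mmd_loss = MMD(prior_samples, z)
    return recon_loss + mmd_weight * mmd_loss, recon_loss, mmd_loss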

Can you give me some possible solutions? Thank you very much!

Saswatm123 commented 7 months ago

Here is my opinion on what is happening, along with some tests you can run. (I meant to answer months ago, but I am not sure why my answer disappeared.)

The loss consists of the MMD between a standard Gaussian and the observed latent distribution, plus the reconstruction error. The reconstruction error can be driven artificially low if the network ignores the MMD penalty term, which effectively turns the model into a regular autoencoder.

As a first test, watch the MMD term over the course of training and see whether it reaches an acceptably low level. My guess is that it does not. If it does, then the reconstruction error must be staying high the whole time, meaning the network never learns to properly encode or decode your data. In that case, to tell whether the encoder or the decoder is the problem, and if you have data labels, train a separate classifier (an SVM or something) on the latent codes and see whether it can classify the latent data properly. If it cannot, the encoder is underpowered and is not clustering the data properly; if it can, then the decoder has the same underpowering issue.
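
Concretely, the two tests could look something like the sketch below. Names like encoder, images, and labels are placeholders for your own data pipeline, and the SVM probe is just one convenient choice of classifier:

import torch
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

def latent_probe(encoder, images, labels):
    # Encode a batch (or the whole dataset) and flatten the latent codes.
    with torch.no_grad():
        z = encoder(images)                      # e.g. (N, 4, 25, 40)
        z = z.view(z.size(0), -1).cpu().numpy()  # (N, 4000)
    # Fit a simple SVM and score it on held-out latents; good accuracy
    # suggests the encoder clusters the classes, pointing at the decoder.
    z_train, z_test, y_train, y_test = train_test_split(z, labels, test_size=0.2)
    clf = SVC().fit(z_train, y_train)
    return clf.score(z_test, y_test)

For the first test, it is enough to log the reconstruction term and the MMD term separately at every training step (rather than only their weighted sum) and watch whether the MMD term actually decreases.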