psanch21 / VAE-GMVAE

This repository contains the implementation of the VAE and the Gaussian Mixture VAE (GMVAE) in TensorFlow, with several network architectures.
Apache License 2.0
206 stars, 32 forks

Why is only w used as the input of Pz_wy_graph in GMVAE_graph.py? #11

Open MrShellx-1999820 opened 3 years ago

MrShellx-1999820 commented 3 years ago

In GMVAE, p(x,z,w,y) = p(w)p(y)p(z|w,y)p(x|z). As you mentioned, in the generative process z is generated from p(z|y,w), which means we first need to sample from p(y) and p(w). But in your code (GMVAE_graph.py), the only input to the function Pz_wy_graph is w. Where is y?
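Written out as ancestral sampling, the generative process in the question reads as follows. The priors on w (standard normal) and y (uniform categorical over K values) are assumptions taken from the usual GMVAE formulation; the question does not state them. As in the reply, sigma_y(w) denotes the variance.

```latex
\begin{aligned}
p(x, z, w, y) &= p(w)\, p(y)\, p(z \mid w, y)\, p(x \mid z),\\
w &\sim \mathcal{N}(0, I), \qquad y \sim \mathrm{Cat}\!\left(\tfrac{1}{K}, \dots, \tfrac{1}{K}\right),\\
z \mid w, y &\sim \mathcal{N}\big(\mu_y(w),\, \sigma_y(w)\big),\\
x \mid z &\sim p(x \mid z).
\end{aligned}
```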

psanch21 commented 3 years ago

Hi @MrShellx-1999820,

Thanks for your interest in my work! Notice that the distribution over z is defined as p(z|w,y) = N(mu_y(w), sigma_y(w)), so both parameters (mean and variance) are functions of w and y. In Pz_wy_graph this is implemented as follows. First, we compute h = DenseNet(w) (lines 56-68); this part indeed depends only on w. Then, two for loops output the means and variances for each value of y: one loop computes mean_k = DenseNet_k(h) and the other computes sigma_k = DenseNet_k(h), for k = 1...K. Depending on the value of y, we then select one of the K sets of parameters for p(z|w,y).
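To make the shape of this computation concrete, here is a minimal framework-free sketch in plain Python. It is not the repository's TensorFlow code. The layer sizes, the tanh and softplus activations, the random initialisation, the standard-normal prior on w, the uniform prior on y, and the names PzWY, dense, apply and ancestral_sample are all hypothetical. What it carries over from the explanation is the structure: a trunk that sees only w, K separate mean heads and K separate variance heads, and y acting purely as an index that picks one (mean, variance) pair.

```python
import math
import random

# Hypothetical sizes, for illustration only.
W_DIM, H_DIM, Z_DIM, K = 2, 4, 3, 5


def dense(n_in, n_out, rng):
    """Randomly initialised affine layer, stored as (weights, bias)."""
    weights = [[rng.gauss(0.0, 0.5) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def apply(layer, x):
    """Compute W x + b for a layer returned by dense()."""
    weights, bias = layer
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b
            for row, b in zip(weights, bias)]


def softplus(v):
    """Numerically stable log(1 + exp(v)); used here (an assumption) to keep variances positive."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


class PzWY:
    """Sketch of p(z | w, y): shared trunk on w, then K mean heads and K variance heads."""

    def __init__(self, rng):
        self.trunk = dense(W_DIM, H_DIM, rng)
        self.mean_heads = [dense(H_DIM, Z_DIM, rng) for _ in range(K)]
        self.var_heads = [dense(H_DIM, Z_DIM, rng) for _ in range(K)]

    def all_params(self, w):
        """Return the K means and K variances for a given w (y is not needed yet)."""
        h = [math.tanh(v) for v in apply(self.trunk, w)]  # depends only on w
        means = [apply(head, h) for head in self.mean_heads]
        variances = [[softplus(v) + 1e-6 for v in apply(head, h)]
                     for head in self.var_heads]
        return means, variances

    def sample(self, w, y, rng):
        """Draw z ~ N(mean_y(w), var_y(w)); y only chooses which head's output to use."""
        means, variances = self.all_params(w)
        return [rng.gauss(m, math.sqrt(v)) for m, v in zip(means[y], variances[y])]


def ancestral_sample(model, rng):
    w = [rng.gauss(0.0, 1.0) for _ in range(W_DIM)]  # w ~ N(0, I)   (assumed prior)
    y = rng.randrange(K)                              # y ~ Cat(1/K)  (assumed prior)
    z = model.sample(w, y, rng)                       # z ~ p(z | w, y)
    return w, y, z


rng = random.Random(0)
model = PzWY(rng)
w, y, z = ancestral_sample(model, rng)
print("y =", y, "z =", z)
```

Note that w is the only input the network itself needs; y never enters a layer, it only indexes the list of K outputs. Computing all K parameter sets for a given w also means the same forward pass can serve any code that needs every component at once, not just the one selected by a sampled y.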

I hope this is helpful!