tolstikhin / wae

Wasserstein Auto-Encoders
BSD 3-Clause "New" or "Revised" License

MMD for multi-channel latent #4

Closed: ahmed-fau closed this issue 5 years ago

ahmed-fau commented 6 years ago

Hi,

I am investigating how to implement WAE with a fully convolutional encoder and decoder, so that no fully connected layers are used. Assuming that I am working with 1D data, I have a latent code (i.e. the output of the bottleneck) with dimensionality [batch_size X 1024 channels X 8 samples_per_channel].

I have used the implementation of imq_kernel shown here, which is similar to your implementation. However, it only works with 2D data (no channel dimension).

My question: is it OK to sum each tensor along the channel dimension (which gives a matrix of size [batch_size X number_of_samples]) and then continue as usual? Or is there no way around using a fully connected layer to flatten the tensor before the MMD calculation?

When I used that approach, I got very unstable values for mmd_loss: the values fluctuate between positive and negative, and there is no monotonic decrease!

Finally: when using fully convolutional layers as I described, does the latent code represent samples from the distribution Q(z|x) that should be compared against Gaussian noise with mean 0 and covariance I? Or must the latent in all cases be a vector of means and a covariance matrix that are fitted to a Gaussian (which is your case when using FC layers)? Does this situation (fitting the mean and covariance to a Gaussian) also arise in WAE-GAN?

Sorry for the long post, and many thanks in advance.

tolstikhin commented 6 years ago

Hey,

it's hard for me to answer, because I think I'm misunderstanding some parts of your question. Let me clarify:

[----- You wrote -----] Assuming that I am working with 1D data, I have a latent code (i.e. output of the bottleneck) with dimensionality [batch_size X 1024 channels X 8 samples_per_channel]. [----- ---------- -----]

The latent space (the "bottleneck space", where the encoder spits out the points) can in principle be any space, but if you want to use a Gaussian prior (and from your further questions it seems this is indeed the case) it should be the d-dimensional Euclidean space R^d. If you are mini-batching, the encoder should then spit out vectors of shape [batch_size X d].

[----- You wrote -----] I have used the implementation of imq_kernel shown here which is similar to your implementation . However, this is only working with 2D data (no channel dimension). [----- ---------- -----]

I am not familiar with that code, but from a quick review I see it may have some issues. For instance, these lines seem to forget to square the coordinates before summing them when computing the norms. If that is indeed a bug, it may be causing some of the instabilities...
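For reference, here is a minimal pure-Python sketch of an inverse multiquadratics (IMQ) kernel in which each coordinate difference is squared before summing, i.e. the step the linked code appears to skip. The function names are hypothetical, and the scale `C = 2 * zdim * sigma2` (with prior variance `sigma2 = 1`) is an assumed choice for a standard Gaussian prior, not taken from the linked code.

```python
def squared_distances(a, b):
    """Pairwise squared Euclidean distances between the rows of a and the rows of b."""
    # Each coordinate difference is squared *before* summing over dimensions.
    return [[sum((ai - bi) ** 2 for ai, bi in zip(x, y)) for y in b] for x in a]


def imq_kernel(a, b, scale):
    """IMQ kernel matrix with entries scale / (scale + ||x - y||^2)."""
    return [[scale / (scale + d) for d in row] for row in squared_distances(a, b)]


# Hypothetical usage with 2-dimensional codes.
zdim = 2
scale = 2.0 * zdim * 1.0  # assumed C = 2 * zdim * sigma2 with sigma2 = 1
print(squared_distances([[0.0, 0.0]], [[3.0, 4.0]]))  # [[25.0]]
print(imq_kernel([[0.0, 0.0]], [[3.0, 4.0]], scale))
```

Without the squaring, the "distance" between (0, 0) and (3, 4) would be -7 instead of 25, and the kernel values would no longer be bounded by 1, which can make any MMD built on top of them erratic.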

[----- You wrote -----] Finally: in case of using fully conv layers like what I have described, does the latent code represent a samples from distribution Q(z|x) that should be compared w.r.t Gaussian noise of mean 0 and covariance I ? Or the latent in all cases shall be a vector of means and matrix of co-variance that should be fitted onto Gaussian (which is ur case when using FC layers)?! Does the situation (of fitting mean and covar onto Gaussian) happen also in case of WAE-GAN ? [----- ---------- -----]

I am a bit confused by the terminology you are using. Let me try to clarify. The goal of the WAE regularizer is to make sure that the aggregated posterior, which is the distribution in the latent space obtained by first sampling a data point and then encoding it, matches the prior distribution. It is not the conditional Q(Z|X) that you should match with the Gaussian, but the mixture of these conditionals across multiple input points X sampled from the data.
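In symbols, assuming the notation of the WAE paper (data distribution P_X, encoder Q(Z|X), prior P_Z, latent divergence D_Z), the aggregated posterior is the data-average of the encoder's conditionals, and it is this mixture that the penalty compares with the prior:

```latex
Q_Z(z) := \mathbb{E}_{X \sim P_X}\big[ Q(z \mid X) \big] = \int Q(z \mid x) \, dP_X(x),
\qquad \text{penalty} = \mathcal{D}_Z(Q_Z, P_Z) \quad \text{rather than} \quad \mathcal{D}_Z\big(Q(Z \mid X = x), P_Z\big).
```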

Let me know if there are still questions remaining.

ahmed-fau commented 6 years ago

Sorry for my complicated description.

What I was doing is simply encoding the input mini-batch (without a fully connected layer at the encoder output, so the output is still a multi-channel tensor). Then I compare this multi-channel encoded representation with random Gaussian samples of the same dimensionality using the MMD measure to get the regularization term. Finally, I feed that multi-channel encoded representation to the decoder network to reconstruct the signal.

Accordingly, I was treating that encoded representation as 'samples' from Q(Z|X). Does this make sense?

ahmed-fau commented 6 years ago

"aggregated posterior, which is the distribution in the latent space"

I think this sentence means that the encoder output 'must be' the parameters of the distribution Q(Z|X) (as in a VAE), not samples from Q(Z|X) as I did!

tolstikhin commented 6 years ago

Hey,

[----- You wrote -----] What I was doing is simply encode the input mini-batch (without using fully connected layer at the encoder output, so the output is still a multi-channel tensor). Then I compare this encoded representation (which is multi-channel) with random Gaussian with the same dimensionality in terms of mmd measure to get the regularization term. Finally I feed that multi-channel encoded representation to the decoder network in order to reconstruct the signal. [----- ---------- -----]

Yeah, sure, it makes sense now. So, instead of having a (batch_size, latent_dim)-shaped tensor after encoding as in our implementation, you are ending up with, say (batch_size, latent_dim1, latent_dim2)-shaped one. In other words, for a given input, the code in your case is not a latent_dim-dimensional vector but rather a higher dimensional tensor.

It's hard to say for sure, but I don't see any problems with this approach. You should only be careful when defining the regularization penalties, especially the MMD-based ones, which may be sensitive to the latent space dimensionality. Personally, I would flatten the codes before computing the regularizer. Say your input gets encoded to a (dim1, dim2)-shaped tensor. To compute an MMD loss, I would take the (batch_size, dim1, dim2)-shaped tensor representing the encoded training points, reshape it to (batch_size, dim1 * dim2), and then compare it to a tensor of the same shape populated with independent standard Gaussian entries using these lines. While doing so, you will need to replace opts['zdim'] with dim1 * dim2.

Meanwhile the decoder can operate directly with the (batch_size, dim1, dim2)-shaped input.
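A minimal sketch of that reshaping, in pure Python with nested lists standing in for tensors (in a deep learning framework this would be a single reshape call). The helper names are hypothetical, and `dim1 * dim2` plays the role of opts['zdim']. It also illustrates why summing over the channel dimension, as asked in the original question, is lossy: two different codes can collapse to the same vector, while flattening keeps every coordinate.

```python
import random


def flatten_codes(codes):
    """Reshape a (batch_size, dim1, dim2) nested list to (batch_size, dim1 * dim2)."""
    return [[v for channel in code for v in channel] for code in codes]


def sum_over_channels(codes):
    """Sum out the channel axis: (batch_size, dim1, dim2) -> (batch_size, dim2)."""
    return [[sum(column) for column in zip(*code)] for code in codes]


def sample_prior(batch_size, zdim, rng):
    """Draw a (batch_size, zdim) sample with i.i.d. standard Gaussian entries."""
    return [[rng.gauss(0.0, 1.0) for _ in range(zdim)] for _ in range(batch_size)]


rng = random.Random(0)
dim1, dim2 = 2, 3
# Two different encoded points, each of shape (dim1, dim2).
codes = [
    [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]],
    [[0.0, 1.0, 2.0], [1.0, 0.0, 0.0]],
]
flat = flatten_codes(codes)                        # shape (2, 6); the two codes stay distinct
prior = sample_prior(len(flat), dim1 * dim2, rng)  # same shape as flat, fed to the MMD
print(sum_over_channels(codes))                    # both rows become [1.0, 1.0, 2.0]
```

Only `flat` and `prior` go into the regularizer; `codes` keeps its original shape for the decoder.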

[----- You wrote -----] "aggregated posterior, which is the distribution in the latent space"

I think this sentence means that the encoder output 'must be' parameters of distribution Q(Z|X) (like what is in VAE), not samples from the distribution Q(Z|X) as what I did ! [----- ---------- -----]

No, actually either of these is fine: you can use a Gaussian encoder, any other stochastic encoder, or even a deterministic encoder. All you need to verify is that you can sample from Q(Z|X=x) for any given input point x. If your encoder parametrizes the mean and covariance of a Gaussian, this sampling is defined by (a) running the encoder on an input x to get the Gaussian parameters, and (b) sampling a random code from the corresponding Gaussian. If your encoder is deterministic, then Q(Z|X=x) is just a function mapping an input x to a fixed code.

Note that if you are using a deterministic encoder, Q(Z|X) is a Dirac point mass and obviously cannot match the standard Gaussian. However, the aggregate posterior is a continuous mixture of these point masses.
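To make the distinction concrete, here is a sketch of sampling from Q(Z|X=x) for both kinds of encoder. The toy encoders `gaussian_encoder` and `deterministic_encoder` are hypothetical stand-ins for networks, mapping a scalar input to a 1-dimensional code. Repeated sampling for one fixed x varies only for the Gaussian encoder; the deterministic encoder returns the same code every time (a point mass), yet its codes for many different inputs are spread out, and that spread is what the aggregate posterior sees.

```python
import random
import statistics


def gaussian_encoder(x):
    """Hypothetical encoder returning (mean, std) of a 1-D Gaussian Q(Z|X=x)."""
    return 0.5 * x, 0.1


def deterministic_encoder(x):
    """Hypothetical deterministic encoder: Q(Z|X=x) is a point mass at this value."""
    return 0.5 * x


def sample_gaussian_code(x, rng):
    # (a) run the encoder to get the Gaussian parameters, (b) sample from that Gaussian.
    mean, std = gaussian_encoder(x)
    return mean + std * rng.gauss(0.0, 1.0)


rng = random.Random(0)
x = 1.0
det_codes = [deterministic_encoder(x) for _ in range(5)]        # identical every time
gauss_codes = [sample_gaussian_code(x, rng) for _ in range(5)]  # varies between draws

# Aggregate posterior of the deterministic encoder: encode many different inputs.
data = [rng.gauss(0.0, 2.0) for _ in range(1000)]
agg_codes = [deterministic_encoder(xi) for xi in data]
print(statistics.pstdev(det_codes), statistics.pstdev(agg_codes))
```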

I hope this clarifies some of your questions!

ahmed-fau commented 6 years ago

The first part of your answer is totally clear to me now and similar to what I thought (i.e. to exclude the FC layer from the encoder and use the encoded representation directly to calculate the MMD, after reshaping the multi-channel encoded representation as you described).

'Note that if you are using a deterministic encoder, Q(Z|X) is a Dirac point mass and obviously can not match the standard Gaussian. However, the aggregate posterior is a continuous mixture of the point masses.'

This is the only thing I still need to understand: what is meant by 'aggregate posterior'? Is it the mean of the codes Q(Z|X) computed over the mini-batch, i.e.

aggregate posterior = 1/n sum_i Q(Z|X=x_i), where n is the mini-batch size and i indexes the samples in the mini-batch?

If so, then increasing the mini-batch size should lead to a better MMD measure.

tolstikhin commented 6 years ago

When I say aggregate posterior, ideally I mean the marginal distribution of a code Z when (a) you sample your data point x from the unknown data distribution, and (b) you sample a code Z from Q(Z|X=x).
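The two-step recipe above can be sketched directly. The toy data distribution and stochastic encoder below (`sample_data`, `sample_code`) are hypothetical stand-ins, not the repository's code. Each code in the resulting batch comes from its own input, so the batch as a whole is a sample from the aggregate posterior QZ; nothing is averaged in the latent space.

```python
import random


def sample_data(rng):
    """Hypothetical stand-in for drawing x from the unknown data distribution."""
    return rng.choice([-3.0, 0.0, 3.0]) + rng.gauss(0.0, 0.1)


def sample_code(x, rng):
    """Hypothetical stochastic encoder: draw Z ~ Q(Z|X=x) = N(x / 3, 0.5^2)."""
    return x / 3.0 + 0.5 * rng.gauss(0.0, 1.0)


def sample_aggregate_posterior(batch_size, rng):
    # (a) draw a data point, then (b) draw a code for that data point; repeat.
    return [sample_code(sample_data(rng), rng) for _ in range(batch_size)]


rng = random.Random(1)
qz_batch = sample_aggregate_posterior(64, rng)  # a point cloud from QZ, not one averaged code
```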

When computing MMD, we are using a U-statistic, which is a sample-based estimate of the true MMD. In other words, we are estimating the population MMD between the actual aggregate posterior QZ and the prior PZ (say, a standard Gaussian) based on samples from QZ and PZ. Both samples are of size batch_size. Loosely speaking, we are replacing the distributions QZ and PZ with two point clouds.
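A minimal pure-Python sketch of such a sample-based estimate with an IMQ kernel, following the form of the penalty in Algorithm 2: the within-sample terms skip the diagonal (the U-statistic part) and the cross term averages over all pairs. The kernel `scale` is an assumed parameter. Because this is an unbiased estimate rather than an exactly computed squared distance, it can come out negative, particularly when the two point clouds are already close; this may partly explain mmd_loss values that fluctuate around zero, as reported at the top of this thread.

```python
def imq(x, y, scale):
    """IMQ kernel k(x, y) = scale / (scale + ||x - y||^2)."""
    return scale / (scale + sum((a - b) ** 2 for a, b in zip(x, y)))


def mmd_estimate(qz, pz, scale):
    """Sample-based MMD^2 estimate between encoded codes qz and prior samples pz.

    Both inputs are lists of n >= 2 vectors of equal length.
    """
    n = len(qz)
    pairs = n * (n - 1)
    within_q = sum(imq(qz[i], qz[j], scale) for i in range(n) for j in range(n) if i != j)
    within_p = sum(imq(pz[i], pz[j], scale) for i in range(n) for j in range(n) if i != j)
    cross = sum(imq(q, p, scale) for q in qz for p in pz)
    return within_q / pairs + within_p / pairs - 2.0 * cross / (n * n)


cloud = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
shifted = [[v + 10.0 for v in z] for z in cloud]
print(mmd_estimate(cloud, cloud, scale=4.0) < 0)    # True: identical clouds give a negative value
print(mmd_estimate(cloud, shifted, scale=4.0) > 0)  # True: far-apart clouds give a positive value
```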

[----- You wrote -----] This is the only thing remaining to understand, what does it mean by 'aggregate posterior' ? is it the mean of codes Q(Z|X) calculated over mini-batch ? i.e. [----- ---------- -----]

No, this is not the case. There is no averaging of vectors in the latent space going on. Instead, as can be seen from Algorithm 2, what happens is an averaging of pairwise similarity measures between points from those two clouds.
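For reference, the MMD-based penalty in Algorithm 2 of the WAE paper has, as far as I recall, the following form (worth checking against the paper), with z_1, ..., z_n drawn from the prior P_Z, \tilde{z}_i drawn from Q(Z|X=x_i) for a mini-batch x_1, ..., x_n, a kernel k, and regularization weight \lambda. Every term is an average of pairwise kernel values; no code vectors are averaged.

```latex
\frac{\lambda}{n(n-1)} \sum_{\ell \neq j} k(z_\ell, z_j)
+ \frac{\lambda}{n(n-1)} \sum_{\ell \neq j} k(\tilde{z}_\ell, \tilde{z}_j)
- \frac{2\lambda}{n^2} \sum_{\ell, j} k(z_\ell, \tilde{z}_j)
```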

[----- You wrote -----] if so, then increasing the mini-batch size should lead to better mmd measure [----- ---------- -----]

Indeed, we have reasons to believe that MMD should work better with larger mini-batch sizes, but this is not necessarily the case.

ahmed-fau commented 6 years ago

Aha, so the aggregate posterior is already taken care of by averaging the MMD estimates between samples from PZ and Q(Z|X=x), according to the update rule described in Algorithm 2, and this should ensure that the aggregate posterior, being a continuous mixture, matches the prior.

Cool, so all I need to do is apply the idea of reshaping the multi-channel latent (encoded representation).

tolstikhin commented 6 years ago

I'm glad we clarified this. Let me know if you have any further questions.