bshall / VectorQuantizedVAE

A PyTorch implementation of "Continuous Relaxation Training of Discrete Latent Variable Image Models"
MIT License

3-dimensional codebook #4

Open CDitzel opened 3 years ago

CDitzel commented 3 years ago

Hello and thank you for this repo.

I was wondering whether there is a reason to use a 3-dim embedding instead of a 2-dim codebook.

Is the idea to achieve some form of multi-head Gumbel sampling? The dimension N is carried through all calculations, only to be merged with the feature dimension again at the end.

The baddbmm operation is very expensive, but my preliminary runs show superior reconstruction quality compared to keeping it 2-dimensional.
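For reference, a minimal sketch of why the 3-dim codebook leads to a batched matrix multiply. It assumes the encoder output has already been reshaped to `(N, T, D)` (N codebooks, T positions, D channels per codebook) and a codebook of shape `(N, M, D)`; `batched_sq_distances` is a hypothetical helper, not code taken from the repo. The cross term of the squared distance is computed with `torch.baddbmm`, and a 2-dim codebook is simply the case N = 1.

```python
import torch

def batched_sq_distances(x, codebook):
    """Squared Euclidean distances to N separate codebooks (hypothetical helper).

    x:        (N, T, D) encoder outputs split into N groups (assumed layout)
    codebook: (N, M, D) N codebooks with M entries each
    returns:  (N, T, M)
    """
    # ||x - e||^2 = ||x||^2 + ||e||^2 - 2 x.e
    x_sq = torch.sum(x ** 2, dim=2, keepdim=True)        # (N, T, 1)
    e_sq = torch.sum(codebook ** 2, dim=2).unsqueeze(1)  # (N, 1, M)
    # baddbmm computes beta * input + alpha * (x @ codebook^T), batched over N
    return torch.baddbmm(x_sq + e_sq, x, codebook.transpose(1, 2), beta=1.0, alpha=-2.0)

torch.manual_seed(0)
x = torch.randn(4, 10, 16)       # N=4 groups, T=10 positions, D=16
cb = torch.randn(4, 8, 16)       # M=8 codes per codebook
d = batched_sq_distances(x, cb)
print(d.shape)  # torch.Size([4, 10, 8])
```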

Thank you in advance

bshall commented 3 years ago

Hi @CDitzel, sorry about the delay. The 3-dim codebook comes from the VQ-VAE paper where they write:

In our experiments we define N discrete latents (e.g., we use a field of 32 x 32 latents for ImageNet, or 8 x 8 x 10 for CIFAR10).

On CIFAR10 it definitely improves performance, but judging from the paper it's not necessary for ImageNet (so long as you have a finer spatial resolution).

Is the idea to achieve some form of multi-head Gumbel sampling?

Yeah, each component of the 3rd dimension can be thought of as a separate codebook with the codes concatenated at the end.
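To make the "separate codebooks, concatenated" reading concrete, here is a minimal sketch using hard nearest-neighbour lookup (the Gumbel relaxation and the straight-through gradient are omitted for brevity). It assumes an encoder output of shape `(B, N*D, H, W)`; under this reading, the paper's 8 x 8 x 10 CIFAR10 setting would correspond to H = W = 8 and N = 10. `quantize_multi_codebook` is a hypothetical name, not the repo's API.

```python
import torch

def quantize_multi_codebook(z, codebook):
    """Hard VQ with N independent codebooks (hypothetical helper).

    z:        (B, N*D, H, W) encoder output (assumed layout)
    codebook: (N, M, D)
    returns:  quantized z (same shape) and code indices (B, N, H, W)
    """
    B, C, H, W = z.shape
    N, M, D = codebook.shape
    assert C == N * D, "channels must split evenly into N groups of D"
    # (B, N, D, H, W) -> (N, B, H, W, D) -> (N, B*H*W, D): one set of rows per codebook
    x = z.view(B, N, D, H, W).permute(1, 0, 3, 4, 2).reshape(N, B * H * W, D)
    idx = torch.cdist(x, codebook).argmin(dim=-1)  # (N, B*H*W), nearest code per group
    # pick each group's code from its own codebook
    q = torch.gather(codebook, 1, idx.unsqueeze(-1).expand(-1, -1, D))  # (N, B*H*W, D)
    # undo the reshape, which concatenates the N chosen codes along channels
    q = q.view(N, B, H, W, D).permute(1, 0, 4, 2, 3).reshape(B, C, H, W)
    return q, idx.view(N, B, H, W).permute(1, 0, 2, 3)

torch.manual_seed(0)
z = torch.randn(2, 3 * 4, 8, 8)   # B=2, N=3 groups of D=4 channels
cb = torch.randn(3, 16, 4)        # N=3 codebooks, M=16 entries each
q, idx = quantize_multi_codebook(z, cb)
print(q.shape, idx.shape)  # torch.Size([2, 12, 8, 8]) torch.Size([2, 3, 8, 8])
```

With N = 1 this reduces to ordinary single-codebook VQ; with N > 1 each spatial position is described by N indices instead of one.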

CDitzel commented 3 years ago

Thanks for getting back to me on this one. Hmm, interesting. I am familiar with the paper but did not interpret the quoted lines the way you did. Very interesting =)

As for the Gumbel-softmax part, I am not entirely sure that it is "correct" to use the distances to the codebook entries to parametrize the posterior in the KL term against the uniform prior. I believe it would be more appropriate to use the latents, i.e. the output of the encoder, in the KL, but personally I did not manage to get that working.
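A minimal sketch of the setup under discussion, assuming the posterior logits are the negative squared distances to the codebook entries (an assumption for illustration, not necessarily the repo's exact parametrization). `gumbel_quantize` is a hypothetical helper: it returns a relaxed mixture of codes and the KL divergence from the softmax posterior to a uniform prior over the M codes. Using the encoder output directly, as suggested in the comment, would mean replacing `logits = -dist` with logits computed from the encoder features themselves.

```python
import math
import torch
import torch.nn.functional as F

def gumbel_quantize(dist, codebook, tau=1.0):
    """Gumbel-softmax code selection from distances (hypothetical helper).

    dist:     (N, T, M) squared distances to each codebook entry
    codebook: (N, M, D)
    returns:  relaxed codes (N, T, D) and KL(q || Uniform(M)) per position (N, T)
    """
    logits = -dist  # assumption: closer codes get higher logits
    # differentiable relaxed one-hot sample; each row sums to 1 over the M codes
    y = F.gumbel_softmax(logits, tau=tau, hard=False, dim=-1)
    x_q = torch.bmm(y, codebook)  # (N, T, M) @ (N, M, D) -> (N, T, D)
    # KL(q || uniform) = sum_k q_k log q_k + log M, with q = softmax(logits)
    log_q = F.log_softmax(logits, dim=-1)
    kl = torch.sum(log_q.exp() * log_q, dim=-1) + math.log(dist.size(-1))
    return x_q, kl

torch.manual_seed(0)
dist = torch.rand(2, 5, 8) * 4
cb = torch.randn(2, 8, 3)
x_q, kl = gumbel_quantize(dist, cb, tau=0.5)
print(x_q.shape, kl.shape)  # torch.Size([2, 5, 3]) torch.Size([2, 5])
```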

CDitzel commented 3 years ago

it's not necessary for ImageNet (so long as you have a finer spatial resolution).

Can you please elaborate what you mean by spatial resolution?

I am having trouble getting convergence on a custom dataset of 256x256 images, which the encoder downsamples four times to a 16x16 latent grid. Training error decreases as it should, but test error plateaus early and does not improve further. I have only used the Gumbel version for this.
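One reading of "finer spatial resolution" is simply a larger latent field, i.e. fewer downsampling steps. The arithmetic below assumes each encoder downsampling step halves height and width (stride 2), which is consistent with 256x256 becoming 16x16 after four steps; `latent_field` and `num_latents` are illustrative helpers. For comparison, the 32 x 32 ImageNet field quoted from the paper contains 1024 latents.

```python
def latent_field(image_size: int, num_downsamples: int, num_codebooks: int = 1):
    """Latent field (H, W, N) for a square image, assuming stride-2 downsampling."""
    side = image_size // 2 ** num_downsamples
    return (side, side, num_codebooks)

def num_latents(field):
    """Number of discrete latents (codebook indices) per image."""
    h, w, n = field
    return h * w * n

cifar = latent_field(32, 2, num_codebooks=10)  # (8, 8, 10), the quoted CIFAR10 setting
custom = latent_field(256, 4)                  # (16, 16, 1), the setup described above
finer = latent_field(256, 3)                   # (32, 32, 1), one fewer downsampling step
print(num_latents(cifar), num_latents(custom), num_latents(finer))  # 640 256 1024
```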

Thank you very much in advance