MishaLaskin / vqvae

A PyTorch implementation of the vector-quantized variational autoencoder (https://arxiv.org/abs/1711.00937)
641 stars, 76 forks

About the shape of variable 'min_encoding_indices' #4

Open · ZhanYangen opened this issue 3 years ago

ZhanYangen commented 3 years ago

Hi,

Having trained the VQVAE model, I managed to extract the encoded codes, i.e. `min_encoding_indices`. However, I found that its shape is (batch_size*image_size/16, 1), where image_size is the number of pixels per image. For example, for a batch of 10 images of 64×64 pixels, the shape of this variable is (10*16*16, 1), i.e. (2560, 1). This confuses me a little, because when I use it to train the PixelCNN in the following code, the shape of `x` after running `for batch_idx, (x, label) in enumerate(test_loader):` is (batch_size, 1). Besides, I don't quite understand why labels filled entirely with zeros are needed. Does this code need any modification? Thanks.
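A minimal sketch of the shape bookkeeping in the question, under stated assumptions: the /16 factor is read as a 4× reduction per side (64×64 pixels become a 16×16 grid of code indices), random integers stand in for the extracted indices, the codebook size of 512 is illustrative, and the all-zero tensor is only a placeholder for the labels the loader loop unpacks (presumably because the PixelCNN takes a class label; the thread does not say). Grouping the flat indices into one index map per image makes each dataset item a whole image rather than a single latent cell.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size, image_side, downsample = 10, 64, 4  # /16 read as 4x per side (assumption)
s = image_side // downsample                     # side of the latent grid: 16

# Stand-in for the flat indices from one batch: (batch_size*s*s, 1) = (2560, 1)
flat_indices = torch.randint(0, 512, (batch_size * s * s, 1))

# Group the rows into one s x s index map per image: (10, 16, 16)
index_maps = flat_indices.view(batch_size, s, s)

# Placeholder labels, one per image (the all-zero labels from the question)
dummy_labels = torch.zeros(batch_size, dtype=torch.long)

loader = DataLoader(TensorDataset(index_maps, dummy_labels), batch_size=4)
for batch_idx, (x, label) in enumerate(loader):
    # x is now (batch, s, s) instead of (batch, 1)
    print(batch_idx, tuple(x.shape), tuple(label.shape))
```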

eiriksteen commented 9 months ago

I think you just need to reshape `min_encoding_indices` into a shape the PixelCNN can work with. I use this code (`bsz` is the batch size, `s` is the spatial dim of z, and `c` is the channel dim of z):

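```python
# d: distances from each flattened z vector to every codebook embedding
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)  # (bsz*s*s, 1) when c == self.e_dim
# Regroup into one s x s index map per image; s*c//self.e_dim equals s when c == self.e_dim
min_encoding_indices_batched = min_encoding_indices.view(
    bsz, s*c//self.e_dim, s*c//self.e_dim, 1)
# (bsz, s, s, 1) -> (bsz, 1, s, s)
min_encoding_indices_batched = min_encoding_indices_batched.permute(0, 3, 1, 2)
```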
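The reshape in the snippet can be checked outside the model. This standalone sketch makes a few assumptions: random indices stand in for `torch.argmin(d, dim=1)`, z is assumed to have been flattened from a (bsz, s, s, c) layout with c equal to `e_dim` (so `s*c//self.e_dim` reduces to `s`), and the sizes are illustrative rather than taken from the repo.

```python
import torch

bsz, s, num_codes = 10, 16, 512  # illustrative sizes

# Stand-in for torch.argmin(d, dim=1).unsqueeze(1): one index per latent cell,
# assumed ordered image by image, then row by row within each s x s grid
min_encoding_indices = torch.randint(0, num_codes, (bsz * s * s, 1))

# (bsz*s*s, 1) -> (bsz, s, s, 1) -> (bsz, 1, s, s)
batched = min_encoding_indices.view(bsz, s, s, 1).permute(0, 3, 1, 2)
print(batched.shape)  # torch.Size([10, 1, 16, 16])

# Each image's index map is a contiguous slice of the flat tensor
i = 3
assert torch.equal(batched[i, 0].reshape(-1, 1),
                   min_encoding_indices[i * s * s:(i + 1) * s * s])

# Undoing the permute and view recovers the original flat layout exactly
restored = batched.permute(0, 2, 3, 1).reshape(-1, 1)
assert torch.equal(restored, min_encoding_indices)
```

If the ordering assumption does not hold for your model (for example, if z was flattened channel-first), the slice check above would fail, which makes it a cheap sanity test to run against real indices.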

You should be able to train a PixelCNN with `min_encoding_indices_batched` as the input. To sample, reshape the sampled indices back to the shape of `min_encoding_indices` and proceed with the code in the repo.
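For the sampling direction, here is a minimal sketch of turning a (bsz, 1, s, s) map of code indices (for example, PixelCNN samples) back into quantized vectors for the decoder. The helper name `indices_to_quantized` and its `codebook` argument, a (num_codes, e_dim) tensor standing in for the quantizer's embedding weights, are hypothetical; the one-hot product is one standard way to do the lookup, not necessarily the repo's exact code.

```python
import torch


def indices_to_quantized(indices_batched: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (bsz, 1, s, s) code indices -> (bsz, e_dim, s, s) vectors.

    `codebook` is assumed to have shape (num_codes, e_dim).
    """
    bsz, _, s, _ = indices_batched.shape
    num_codes, e_dim = codebook.shape
    # Undo the permute/view: (bsz, 1, s, s) -> (bsz, s, s, 1) -> (bsz*s*s, 1)
    flat = indices_batched.permute(0, 2, 3, 1).reshape(-1, 1)
    # One-hot rows; the matmul then picks one codebook vector per latent cell
    one_hot = torch.zeros(flat.shape[0], num_codes,
                          dtype=codebook.dtype, device=codebook.device)
    one_hot.scatter_(1, flat, 1.0)
    z_q = one_hot @ codebook  # (bsz*s*s, e_dim)
    # Restore the spatial layout and move channels first for a conv decoder
    return z_q.view(bsz, s, s, e_dim).permute(0, 3, 1, 2)


# Usage with illustrative sizes: 512 codes of dimension 64, 16x16 grids
codebook = torch.randn(512, 64)
sampled = torch.randint(0, 512, (2, 1, 16, 16))
z_q = indices_to_quantized(sampled, codebook)
print(z_q.shape)  # torch.Size([2, 64, 16, 16])
```

Indexing `codebook[flat.squeeze(1)]` selects the same rows without building the one-hot matrix, which saves memory when the codebook is large.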