ZhanYangen opened this issue 3 years ago
I think you just need to reshape `min_encoding_indices` into something the PixelCNN can work with. I use this code (`bsz` is the batch size, `s` is the spatial dim of `z`, and `c` is the channel dim of `z`):
```python
# indices of the nearest codebook vector for each flattened z position
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
# reshape the flat (N, 1) index tensor into a batched (bsz, h, w, 1) grid
min_encoding_indices_batched = min_encoding_indices.view(
    bsz, s * c // self.e_dim, s * c // self.e_dim, 1)
# permute NHWC -> NCHW so the PixelCNN gets a (bsz, 1, h, w) input
min_encoding_indices_batched = min_encoding_indices_batched.permute(
    0, 3, 1, 2)
```
You should be able to train a PixelCNN with `min_encoding_indices_batched` as the input. To sample, you reshape it back to the flat `min_encoding_indices` layout and proceed with the code in the repo; a sketch of the inverse reshape is below.
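For reference, a minimal sketch of that inverse reshape, assuming a `(bsz, 1, h, w)` grid of sampled indices; the tensor names and the codebook size of 512 are illustrative, not taken from the repo:

```python
import torch

bsz, h, w = 10, 16, 16  # example sizes; h and w are the code-grid dims

# stand-in for a grid of codebook indices sampled from the PixelCNN (NCHW)
sampled = torch.randint(0, 512, (bsz, 1, h, w))

# invert the permute (NCHW -> NHWC), then flatten back to the repo's
# (bsz*h*w, 1) min_encoding_indices layout for decoding
flat_indices = sampled.permute(0, 2, 3, 1).reshape(-1, 1)
print(flat_indices.shape)  # torch.Size([2560, 1])
```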
Hi, having trained the VQVAE model, I managed to extract the encoded code, i.e. `min_encoding_indices`. However, I found that its shape is `(batch_size * image_size / 16, 1)`. That is, taking a batch of 10 images of 64×64 pixels as an example, the shape of this variable would be `(10 * 16 * 16, 1)`, i.e. `(2560, 1)`. This confuses me a little, because in the following code, when I need to use it to train the PixelCNN, after running `for batch_idx, (x, label) in enumerate(test_loader):` the shape of `x` would be `(batch_size, 1)`. Besides, I don't quite understand why labels all filled with zeros are needed. Is there any modification needed to this code? Thanks.
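For concreteness, here is a minimal sketch of the shapes involved: regrouping the flat indices per image and pairing them with all-zero placeholder labels. The sizes, the codebook size of 512, and the use of `TensorDataset` are illustrative assumptions, not code from the repo:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

bsz, h, w = 10, 16, 16  # 64x64 inputs -> a 16x16 grid of code indices

# stand-in for the flat (bsz*h*w, 1) = (2560, 1) min_encoding_indices
flat_indices = torch.randint(0, 512, (bsz * h * w, 1))

# regroup per image: (2560, 1) -> (10, 1, 16, 16), NCHW for the PixelCNN
codes = flat_indices.view(bsz, h, w, 1).permute(0, 3, 1, 2)

# an unconditional PixelCNN ignores the label, so zeros work as placeholders
labels = torch.zeros(bsz, dtype=torch.long)
loader = DataLoader(TensorDataset(codes, labels), batch_size=2)

for batch_idx, (x, label) in enumerate(loader):
    print(x.shape, label.shape)  # torch.Size([2, 1, 16, 16]) torch.Size([2])
```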