zalandoresearch / pytorch-vq-vae

PyTorch implementation of VQ-VAE by Aäron van den Oord et al.
MIT License

dimension issue #8

Open jlian2 opened 4 years ago

jlian2 commented 4 years ago

```python
def forward(self, inputs):
    # convert inputs from BCHW -> BHWC
    inputs = inputs.permute(0, 2, 3, 1).contiguous()
    input_shape = inputs.shape

    # Flatten input
    flat_input = inputs.view(-1, self._embedding_dim)
```

My understanding is that the shape of `flat_input` should be `(B*H*W*C, embedding_dim)`, so one dimension seems to be missing? Or are you saying the number of channels equals `embedding_dim`?
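For reference, here is a minimal shape check of what I see happening (the sizes are made up, and I am assuming the encoder's output channel count `C` equals `embedding_dim`, as the `view` call only works in that case):

```python
import torch

# Hypothetical sizes, chosen only for illustration
B, C, H, W = 2, 64, 8, 8
embedding_dim = 64  # assumption: C == embedding_dim

inputs = torch.randn(B, C, H, W)

# BCHW -> BHWC, as in the forward pass above
x = inputs.permute(0, 2, 3, 1).contiguous()

# view(-1, embedding_dim) folds B, H, and W into one axis:
# each of the B*H*W spatial positions becomes one
# embedding_dim-sized vector (the channel axis).
flat_input = x.view(-1, embedding_dim)

print(flat_input.shape)  # torch.Size([128, 64]) == (B*H*W, embedding_dim)
```

So the flattened tensor has `B*H*W` rows, not `B*H*W*C`; the channel dimension is the `embedding_dim` column axis.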