rosinality / vq-vae-2-pytorch

Implementation of Generating Diverse High-Fidelity Images with VQ-VAE-2 in PyTorch

Quantization process is wrong #28

Closed danieltudosiu closed 4 years ago

danieltudosiu commented 4 years ago

Hi,

I was going to use your code in one of my projects, but the quantization is broken; specifically, this line.

Since you just "copied" the code from TensorFlow, you did not account for the change in dimension order. In TensorFlow the tensor is [B, W, H, D, F], while in PyTorch it is [B, F, W, H, D] (I'm not sure about the exact W, H, D order), so that line is invalid in PyTorch.
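To illustrate the concern (a minimal standalone sketch, not the repo's code): flattening a channels-first tensor directly mixes values from different feature channels, whereas moving channels to the last axis first yields one codebook-sized feature vector per spatial position.

```python
import torch

B, C, H, W = 2, 4, 3, 3
x = torch.randn(B, C, H, W)              # typical PyTorch conv output, channels-first

# Wrong for a codebook lookup: flattening channels-first memory directly
# packs values from different channels/positions into each row.
bad = x.reshape(-1, C)

# Right: permute to channels-last first, then flatten, so each row is the
# C-dim feature vector of a single spatial position (TensorFlow-style layout).
good = x.permute(0, 2, 3, 1).reshape(-1, C)

# Each row of `good` is exactly x[b, :, h, w] for some (b, h, w):
assert torch.equal(good[0], x[0, :, 0, 0])
```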

Cheers,

Dan

wrrogers commented 4 years ago

I think you are mistaken, Daniel. That line converts the embedding from a flattened array. The inputs come in as [B, W, H, D], and only [B, W, H] are grabbed via `*input.shape[:-1]`, which results in an embedding of size [B, W, H].

danieltudosiu commented 4 years ago

That's exactly the problem I'm describing: in PyTorch the feature channel is at index 1, not at index -1 as in TensorFlow (in which the original code was developed).

iimog commented 4 years ago

Could it be that the code in lines 201 and 209 is there to address that issue? https://github.com/rosinality/vq-vae-2-pytorch/blob/e851d8170709cbe0cdc9521a52f5e0516ffece0c/vqvae.py#L201-L212
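In other words, the pattern at those lines is roughly the following (a paraphrased sketch, not a verbatim copy of vqvae.py): permute the encoder output to channels-last before the codebook lookup, and permute the result back to channels-first for the decoder.

```python
import torch

enc_out = torch.randn(2, 64, 8, 8)            # [B, C, H, W] from a hypothetical encoder
channels_last = enc_out.permute(0, 2, 3, 1)   # [B, H, W, C], the layout Quantize expects
# ... the quantize module would run on `channels_last` here ...
back = channels_last.permute(0, 3, 1, 2)      # [B, C, H, W] again for the decoder

# The round trip is lossless, so the rest of the network is unaffected:
assert torch.equal(back, enc_out)
```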

danieltudosiu commented 4 years ago

@iimog Yes, you are right about that. I didn't look at the whole example, since I wasn't interested in anything but the quantization process itself.