Closed: danieltudosiu closed this issue 4 years ago
I think you are mistaken, Daniel. That line looks like it converts the embedding indices from a flattened array. The inputs come in as [B, W, H, D], and only the [B, W, H] dimensions are grabbed via
*input.shape[:-1]
which results in an embedding index tensor of shape [B, W, H].
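For reference, here is a tiny sketch (with made-up shapes and variable names, not the repo's code) of what that unpacking does:

```python
import torch

# hypothetical channels-last input, e.g. [B, W, H, D] = [2, 8, 8, 64]
inp = torch.randn(2, 8, 8, 64)

# flat index tensor with one entry per spatial position
flat_ind = torch.zeros(2 * 8 * 8, dtype=torch.long)

# *inp.shape[:-1] unpacks to (2, 8, 8), i.e. every dimension except the last
ind = flat_ind.view(*inp.shape[:-1])
print(ind.shape)  # torch.Size([2, 8, 8])
```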
That's exactly the problem I am pointing out: in PyTorch the feature channel is at index 1, not at index -1 as in TensorFlow (in which the original code was developed).
Could it be that the code in lines 201 and 209 is there to address that issue? https://github.com/rosinality/vq-vae-2-pytorch/blob/e851d8170709cbe0cdc9521a52f5e0516ffece0c/vqvae.py#L201-L212
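For anyone following along, a minimal sketch of the call-site pattern those lines appear to implement (the function name and the quantizer's return signature here are assumptions for illustration, not the repo's exact code): the channels-first PyTorch tensor is permuted to channels-last before quantization and permuted back afterwards, so the `*input.shape[:-1]` indexing inside the quantizer remains valid.

```python
import torch

def quantize_channels_last(quantizer, x):
    # x: [B, C, H, W] -- the usual PyTorch channels-first layout
    x = x.permute(0, 2, 3, 1)          # -> [B, H, W, C], the layout the quantizer expects
    quant, ids = quantizer(x)          # indexing with *input.shape[:-1] now yields [B, H, W]
    quant = quant.permute(0, 3, 1, 2)  # -> back to [B, C, H, W] for the decoder
    return quant, ids
```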
@iimog Yes, you are right about that. I didn't look at the whole example, since I wasn't interested in anything but the quantization process.
Hi,
I was going to use your code in one of my projects, but the quantization is broken, more specifically this line.
Since you just "copied" the code from TensorFlow, you did not account for the change in the order of the dimensions. In TensorFlow it is [B, W, H, D, F], while in PyTorch it is [B, F, W, H, D] (not sure about the W, H, D order per se), so that line is invalid in PyTorch.
Cheers,
Dan