ritheshkumar95 / pytorch-vqvae

Vector Quantized VAEs - PyTorch Implementation

Distance calculation #15

Closed: imbalu007 closed this issue 4 years ago

imbalu007 commented 4 years ago

Can you please explain how you are computing the distance between the codebook and the inputs? In functions.py, you use this line:

distances = torch.addmm(codebook_sqr + inputs_sqr, inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0)

I am unable to understand how this gives the Euclidean distance between the inputs and the codebook.

tristandeleu commented 4 years ago

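This uses the expansion ||codebook - inputs||^2 = ||codebook||^2 + ||inputs||^2 - 2 * <codebook, inputs>, applied to every pair of an input vector and a codebook vector. The result is therefore the squared Euclidean distance, which is all you need to pick the nearest code, because squaring does not change which code is closest. See the torch.addmm documentation for the order of the arguments (which is a bit misleading): torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha) computes beta * input + alpha * (mat1 @ mat2), so here it returns (codebook_sqr + inputs_sqr) - 2 * (inputs_flatten @ codebook.t()).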
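Here is a minimal sketch that checks the expansion numerically. It assumes inputs_flatten has shape (N, D) and codebook has shape (K, D). It also assumes codebook_sqr and inputs_sqr are the per-row squared norms, with shapes (K,) and (N, 1), so that their sum broadcasts to (N, K). Those two definitions are an assumption here and are not copied from functions.py. The helper name vq_distances and the sizes are hypothetical.

```python
import torch

def vq_distances(inputs_flatten: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: squared Euclidean distances between (N, D) inputs and (K, D) codes."""
    # ||c_k||^2 for every code, shape (K,): broadcasts along the rows
    codebook_sqr = torch.sum(codebook ** 2, dim=1)
    # ||x_n||^2 for every input, shape (N, 1): broadcasts along the columns
    inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True)
    # addmm computes beta * input + alpha * (mat1 @ mat2), i.e.
    # ||x_n||^2 + ||c_k||^2 - 2 * <x_n, c_k>, with shape (N, K)
    return torch.addmm(codebook_sqr + inputs_sqr, inputs_flatten, codebook.t(),
                       alpha=-2.0, beta=1.0)

torch.manual_seed(0)
x = torch.randn(6, 4)   # hypothetical: N=6 input vectors of dimension D=4
c = torch.randn(3, 4)   # hypothetical: K=3 codebook vectors

# Brute force for comparison: materializes an (N, K, D) tensor of differences
brute = ((x[:, None, :] - c[None, :, :]) ** 2).sum(dim=-1)
print(torch.allclose(vq_distances(x, c), brute, atol=1e-4))  # True
```

The expansion subtracts nearly equal numbers, so it can return tiny negative values when an input is almost identical to a code. That does not affect the argmin. Clamp the result at zero before taking a square root.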

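Computing the distance this way avoids creating the (possibly very large) intermediate tensor of pairwise differences codebook - inputs. For N input vectors, K codes and D dimensions, that tensor has shape (N, K, D), whereas the terms here are only of shape (N, K). As a result, you can fit larger inputs / codebooks on the GPU.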
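The savings are easiest to see with a back-of-the-envelope calculation. The sizes below are hypothetical and do not come from the repository: a batch of 32 latent maps of 32 × 32, a 512-entry codebook of 256-dimensional codes, and float32 storage.

```python
# Hypothetical sizes, chosen only for illustration
batch, height, width = 32, 32, 32
N = batch * height * width   # number of flattened latent vectors: 32768
K, D = 512, 256              # codebook size and code dimension
bytes_per_float = 4          # float32

diff_bytes = N * K * D * bytes_per_float   # (N, K, D) tensor of codebook - inputs
dist_bytes = N * K * bytes_per_float       # (N, K) distance matrix from addmm

print(diff_bytes / 2**30, "GiB")  # 16.0 GiB
print(dist_bytes / 2**20, "MiB")  # 64.0 MiB
```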

Roller44 commented 1 year ago

How about using torch.cdist?
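Here is a minimal sketch comparing the two approaches on hypothetical sizes and random data. Note that torch.cdist(x1, x2, p=2.0) returns the Euclidean distance itself, not its square. Squaring is monotonic on non-negative values, so both give the same nearest code up to floating-point ties. According to the PyTorch documentation, the default compute_mode ('use_mm_for_euclid_dist_if_necessary') switches to a matrix-multiplication formulation when either input has more than 25 rows. In that case it also avoids the (N, K, D) intermediate.

```python
import torch

torch.manual_seed(0)
inputs_flatten = torch.randn(100, 16)   # hypothetical: N=100 vectors, D=16
codebook = torch.randn(32, 16)          # hypothetical: K=32 codes

# torch.cdist returns the (non-squared) Euclidean distance, shape (N, K)
dist = torch.cdist(inputs_flatten, codebook, p=2.0)

# Squared distances via the addmm expansion from the question
# (codebook_sqr / inputs_sqr defined here as per-row squared norms, an assumption)
codebook_sqr = torch.sum(codebook ** 2, dim=1)
inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True)
sq_dist = torch.addmm(codebook_sqr + inputs_sqr, inputs_flatten, codebook.t(),
                      alpha=-2.0, beta=1.0)

# Squaring is monotonic on non-negative values, so the nearest code agrees
# (up to floating-point ties between almost equally close codes)
indices_cdist = dist.argmin(dim=1)
indices_addmm = sq_dist.argmin(dim=1)
```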