Dear @lucidrains,

Throughout my experiments with this wonderful library, I found some weird behaviour when using lower codebook dimensions:
```python
# everything okay
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    codebook_dim = 16   # paper proposes setting this to 32 or as low as 8 to increase codebook usage
)
vq.eval()

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```
When I try to recover the output from the indices, the script crashes:
```python
output = vq.get_output_from_indices(indices)
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x16384 and 16x256)
print(output.shape)
```
I tracked this down to `get_codes_from_indices`. Because `dim != codebook_dim`, the codes returned there still have to be passed through `self.project_out`, which expects the last dimension to be `codebook_dim`.
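To illustrate where the shapes go wrong, here is a minimal sketch of my understanding; the `Linear(16, 256)` layer is only a stand-in for `self.project_out`, not the library's actual module:

```python
import torch
from einops import rearrange

# one 16-dimensional code per token, as produced by the codebook lookup
codes = torch.randn(1, 1024, 16)

# the rearrange treats the 1024 token positions as heads and flattens them
flat = rearrange(codes, '... h d -> ... (h d)')  # shape (1, 16384)

# stand-in for self.project_out, which maps codebook_dim (16) back to dim (256)
project_out = torch.nn.Linear(16, 256)
project_out(flat)  # RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x16384 and 16x256)
```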
If I remove the `rearrange(codes, '... h d -> ... (h d)')` call, everything works as expected. Please find the full example on Colab. In case I made a mistake, I apologize; I am still new to PyTorch/this library...
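Until this is patched, the following caller-side workaround seems to do the job; it is only a sketch and assumes the single-head setup from the example above, so the flattened last dimension factors cleanly into `(1024, 16)`:

```python
from einops import rearrange

# undo the flatten applied inside get_codes_from_indices, then project
# back up to the model dimension manually
codes = vq.get_codes_from_indices(indices)           # (1, 16384) after the faulty rearrange
codes = rearrange(codes, 'b (n d) -> b n d', d=16)   # restore (1, 1024, 16); d = codebook_dim
output = vq.project_out(codes)                       # (1, 1024, 256)
print(output.shape)  # torch.Size([1, 1024, 256])
```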
Thanks, Nikolai