lucidrains / vector-quantize-pytorch

Vector (and Scalar) Quantization, in Pytorch
MIT License

Potential Bug for Lower codebook Dimension/ get_output_from_indices (improved VQGAN) #113

Closed: Nikolai10 closed this issue 5 months ago

Nikolai10 commented 6 months ago

Dear @lucidrains,

Throughout my experiments with this wonderful library, I found some weird behaviour when using a lower codebook dimension:

```python
# everything okay
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    codebook_dim = 16      # paper proposes setting this to 32 or as low as 8 to increase codebook usage
)

vq.eval()

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```
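For context, here is a library-free sketch of the forward path with codebook_dim = 16. It assumes that, when codebook_dim != dim, the input is projected down by a linear layer before the nearest-neighbour lookup and projected back up afterwards; the names project_in and project_out mirror the attribute discussed below, but this is an illustration, not the library's code (which may also normalize, use EMA updates, etc.).

```python
import torch
from torch import nn

torch.manual_seed(0)

dim, codebook_size, codebook_dim = 256, 256, 16

# Assumption: plain linear projections around a (codebook_size, codebook_dim) codebook
project_in = nn.Linear(dim, codebook_dim)
project_out = nn.Linear(codebook_dim, dim)
codebook = torch.randn(codebook_size, codebook_dim)

x = torch.randn(1, 1024, dim)

with torch.no_grad():
    z = project_in(x)                                # (1, 1024, 16)
    # nearest codebook entry by Euclidean distance
    dist = torch.cdist(z, codebook.unsqueeze(0))     # (1, 1024, 256)
    indices = dist.argmin(dim=-1)                    # (1, 1024)
    quantized = project_out(codebook[indices])       # (1, 1024, 256)

print(indices.shape, quantized.shape)
```

The point of the sketch: the codes looked up from the indices live in the 16-dimensional codebook space, so recovering the output from indices has to end with project_out applied to a tensor whose last axis is 16.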

When I try to recover the output from the indices, the script crashes:

```python
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x16384 and 16x256).
output = vq.get_output_from_indices(indices)
print(output.shape)
```

I tracked this down to get_codes_from_indices.

```python
if not is_multiheaded:
  codes = codebook[indices]
  return rearrange(codes, '... h d -> ... (h d)') # why is this line required?
```

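The reason is that, when dim != codebook_dim, the codes still have to go through self.project_out, which expects the last axis to have size codebook_dim (16 here). With no head axis present, the rearrange treats the sequence axis (1024) as h and merges it into the feature axis, producing a (1, 16384) tensor; 1024 * 16 = 16384 is exactly the mismatch reported in the error.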
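The shape problem can be reproduced with plain PyTorch. This minimal sketch assumes a (256, 16) codebook and a linear projection from 16 back to 256 dimensions; it uses flatten(start_dim=-2) as a stand-in for rearrange(codes, '... h d -> ... (h d)'), which merges the last two axes in the same way.

```python
import torch
from torch import nn

codebook_size, codebook_dim, dim = 256, 16, 256
codebook = torch.randn(codebook_size, codebook_dim)
project_out = nn.Linear(codebook_dim, dim)
indices = torch.randint(0, codebook_size, (1, 1024))

codes = codebook[indices]            # (1, 1024, 16)

# Stand-in for rearrange(codes, '... h d -> ... (h d)'): with no head axis,
# the sequence axis (1024) plays the role of "h" and is merged into the features.
flat = codes.flatten(start_dim=-2)   # (1, 16384)

try:
    project_out(flat)
    crashed = False
except RuntimeError as err:
    crashed = True
    print(err)  # shape-mismatch message: 16384 input features vs. 16 expected

output = project_out(codes)          # (1, 1024, 256): works without the flatten
```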

If I remove rearrange(codes, '... h d -> ... (h d)'), everything works as expected:

```python
# codes = codebook[indices], without the rearrange: shape (1, 1024, 16)
proj_out = vq.project_out(codes)

# returns tensor(True)
torch.all(quantized == proj_out)
```
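Judging by its pattern, the rearrange seems intended for a multi-head layout, where indices carry an extra head axis. A minimal sketch of a guarded lookup, assuming one codebook of shape (codebook_size, codebook_dim) shared across heads; codes_from_indices is a hypothetical helper for illustration, not the library's function or its actual fix:

```python
import torch

def codes_from_indices(codebook: torch.Tensor, indices: torch.Tensor, multiheaded: bool) -> torch.Tensor:
    """Hypothetical helper (not the library's code).

    codebook: (codebook_size, codebook_dim), assumed shared across heads
    indices:  (...) for a single head, or (..., heads) when multiheaded
    """
    codes = codebook[indices]
    if not multiheaded:
        # (..., codebook_dim): already the shape project_out expects
        return codes
    # (..., heads, codebook_dim) -> (..., heads * codebook_dim)
    return codes.flatten(start_dim=-2)

codebook = torch.randn(256, 16)
single = codes_from_indices(codebook, torch.randint(0, 256, (1, 1024)), multiheaded=False)
multi = codes_from_indices(codebook, torch.randint(0, 256, (1, 1024, 4)), multiheaded=True)
print(single.shape, multi.shape)  # (1, 1024, 16) and (1, 1024, 64)
```

Only merging the axes when a head axis actually exists keeps the single-head codes at codebook_dim, which is what a projection back to dim needs.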

The full example is on Colab. If I made a mistake, I apologize; I am still new to PyTorch and this library.

Thanks, Nikolai

lucidrains commented 5 months ago

@Nikolai10 hey Nikolai! Thanks for reporting this! I ran into the same issue in another project and ended up fixing it here

Nikolai10 commented 5 months ago

Great, thanks!