lucidrains / vector-quantize-pytorch

Vector (and Scalar) Quantization, in Pytorch
MIT License
2.12k stars 179 forks source link

How to get the original code by saved codebook and index #102

Open asher-bit opened 5 months ago

asher-bit commented 5 months ago

Hi, I am trying to reconstruct the quantized output from a saved codebook and saved indices, but the `codebooks` of the newly created `ResidualVQ` are always all zeros. How can I do this?

import torch
from vector_quantize_pytorch import ResidualVQ

# Quantize a batch and persist what is needed to decode it later. Then rebuild a
# fresh ResidualVQ and check that it decodes the saved indices to the same output.
#
# Saving only `residual_vq.codebooks` is not enough to restore the quantizer.
# Assigning that tensor back (`residual_vq_reinit.codebook = codebook`) merely
# attaches an unused attribute to the new module; the module's own codebooks
# stay uninitialized, which is why they read as zeros. Saving the module's
# state_dict() and loading it with load_state_dict() restores the registered
# codebook buffers of every layer instead.
residual_vq = ResidualVQ(
    dim = 256,
    codebook_size = 256,
    num_quantizers = 2,
    kmeans_init = True,
    kmeans_iters = 10
)

x = torch.randn(1, 1024, 256)

residual_vq.eval()

quantized, indices, commit_loss = residual_vq(x)

codebook = residual_vq.codebooks  # get the codebook (kept in the checkpoint for inspection)

torch.save(
    dict(codebook = codebook, indices = indices, state_dict = residual_vq.state_dict()),
    'codebook.pt'
)

checkpoint = torch.load('codebook.pt')  # read the file once, then pick out entries
codebook = checkpoint['codebook']
indices = checkpoint['indices']

residual_vq_reinit = ResidualVQ(
    dim = 256,
    codebook_size = 256,
    num_quantizers = 2,
    kmeans_init = True,
    kmeans_iters = 10
)
# Restore the real codebooks (and any other buffers) instead of a loose tensor.
residual_vq_reinit.load_state_dict(checkpoint['state_dict'])
residual_vq_reinit.eval()

# Presumably the same reconstruction forward() returns: per-quantizer codes
# summed, followed by the output projection. Verify against the library docs.
quantized_out = residual_vq_reinit.get_output_from_indices(indices)

# forward() may route its output through extra float ops (e.g. a straight-through
# or rotation estimator), so exact `==` is fragile; compare with a tolerance.
assert torch.allclose(quantized, quantized_out, atol = 1e-4)
crazyrayLing commented 5 months ago

# try this way! Save the whole module with state_dict() and load it into the
# fresh instance. Assigning a saved codebook tensor to an attribute of the new
# module does not change the codebooks the quantizer actually uses.
import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    codebook_size = 256,
    num_quantizers = 2,
    kmeans_init = True,
    kmeans_iters = 10
)

x = torch.randn(1, 1024, 256)

# NOTE(review): this forward runs in training mode, so any codebook update it
# performs happens after `quantized` is produced. Compare decodings made after
# the forward (below) rather than `quantized` itself.
quantized, indices, commit_loss = residual_vq(x)

codebook = residual_vq.codebooks  # get the codebook
print(indices)
print(quantized)
print(codebook)

quantized_out = residual_vq.get_output_from_indices(indices)
print(quantized_out)

torch.save(dict(codebook = codebook, indices = indices), 'codebook.pt')
saved = torch.load('codebook.pt')  # load once, read both entries
codebook = saved['codebook']
indices = saved['indices']
print(codebook)
print(indices)

torch.save(residual_vq.state_dict(), "residual_vq.pt")

residual_vq_reinit = ResidualVQ(
    dim = 256,
    codebook_size = 256,
    num_quantizers = 2,
    kmeans_init = True,
    kmeans_iters = 10
)
residual_vq_reinit.load_state_dict(torch.load("residual_vq.pt", map_location=torch.device('cpu')))
# The reloaded module is the one being evaluated (previously eval() was called
# on residual_vq by mistake). load_state_dict already restored the codebooks, so
# no extra `.codebook = codebook` assignment is needed.
residual_vq_reinit.eval()

print(residual_vq_reinit.codebooks)  # the codebooks the module actually uses

quantized_out_reinit = residual_vq_reinit.get_output_from_indices(indices)
print(quantized_out_reinit)

# The reloaded module must decode the saved indices exactly like the original.
assert torch.allclose(quantized_out, quantized_out_reinit)