[Open] asher-bit opened this issue 5 months ago.
# Try this way!
import torch
from vector_quantize_pytorch import ResidualVQ

# Two-stage residual vector quantizer over 256-dim vectors.
# NOTE(review): with kmeans_init=True the codebooks are presumably only
# initialised from data on the first forward pass, so a freshly constructed
# module likely holds placeholder (e.g. all-zero) codebooks until then —
# confirm against the installed vector_quantize_pytorch version.
residual_vq = ResidualVQ(
    dim = 256,
    codebook_size = 256,
    num_quantizers = 2,
    kmeans_init = True,
    kmeans_iters = 10,
)

# Random input of shape (1, 1024, 256); presumably (batch, seq, dim) — the last
# axis must match `dim` above.
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)

# Read the codebooks only after the forward pass above has run.
codebook = residual_vq.codebooks  # get the codebook
print(indices)
print(quantized)
print(codebook)

# Decode the indices back into vectors with this module's own codebooks.
quantized_out = residual_vq.get_output_from_indices(indices)
print(quantized_out)
# Persist the codebook tensor and the code indices together, then read them
# back. The file is deserialised once rather than once per key. map_location
# matches the CPU state_dict load used when rebuilding the module.
torch.save(dict(codebook = codebook, indices = indices), 'codebook.pt')
saved = torch.load('codebook.pt', map_location=torch.device('cpu'))
codebook = saved['codebook']
indices = saved['indices']
print(codebook)
print(indices)
# Save the whole module state, not just the codebook tensor. A fresh
# ResidualVQ is rebuilt from it below.
torch.save(residual_vq.state_dict(), "residual_vq.pt")

# Rebuild with the same hyperparameters, then restore the saved state.
# NOTE(review): a freshly constructed module presumably holds placeholder
# codebooks (likely why `codebooks` looked all-zero) until load_state_dict
# restores the trained ones — confirm against the installed library version.
residual_vq_reinit = ResidualVQ(
    dim = 256,
    codebook_size = 256,
    num_quantizers = 2,
    kmeans_init = True,
    kmeans_iters = 10,
)
state_dict = torch.load("residual_vq.pt", map_location=torch.device('cpu'))
residual_vq_reinit.load_state_dict(state_dict)
# Put the module that is actually used below into eval mode. The original
# snippet called eval() on `residual_vq` by mistake.
residual_vq_reinit.eval()

# Do NOT assign `residual_vq_reinit.codebook = ...`. That only creates a new,
# unrelated attribute (singular `codebook`, not the `codebooks` read earlier).
# The codebooks are already restored by load_state_dict above.
print(residual_vq_reinit.codebooks)
print(torch.equal(residual_vq_reinit.codebooks, codebook))  # expect True

quantized_out = residual_vq_reinit.get_output_from_indices(indices)
print(quantized_out)
Hi, I am trying to reconstruct the quantized output from a saved codebook and saved indices, but `self.codebooks` always comes back as all zeros. How can I do this?