Closed kashif closed 1 year ago
It should work for arbitrary tensor shapes now. Originally it supported only up-to-3-dimensional tensors, which were intended for transformers; the 2D case works as well.
```python
import torch
from vqtorch.nn import VectorQuant

# sequence modeling case: [batch_size x num_tokens x feat_size]
vq_layer = VectorQuant(feature_size=32, num_codes=512, dim=-1).cuda()
z_e = torch.randn(16, 8, 32).cuda()
vq_layer.codebook.weight.data[:128] = z_e.view(-1, 32).data.clone()
z_q, vq_dict = vq_layer(z_e)
assert torch.allclose(z_q, z_e)

# 2D matrix case: [batch_size x feat_size]
vq_layer = VectorQuant(feature_size=32, num_codes=512, dim=-1).cuda()
z_e = torch.randn(128, 32).cuda()
vq_layer.codebook.weight.data[:128] = z_e.data.clone()
z_q, vq_dict = vq_layer(z_e)
assert torch.allclose(z_q, z_e)
```
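For intuition, here is a minimal sketch (not vqtorch's actual implementation) of why quantization over the last dimension works for any number of leading dimensions: flatten everything but the feature axis, look up the nearest codebook entry per vector, and reshape back. The `quantize` helper below is hypothetical.

```python
import torch

def quantize(z_e: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Replace each feature vector in z_e with its nearest codebook entry.

    z_e:      [..., feat_size]  (any number of leading dims)
    codebook: [num_codes, feat_size]
    """
    flat = z_e.reshape(-1, z_e.shape[-1])   # [N, feat_size]
    dists = torch.cdist(flat, codebook)     # [N, num_codes] pairwise distances
    idx = dists.argmin(dim=-1)              # nearest code index per vector
    return codebook[idx].reshape(z_e.shape)

codebook = torch.randn(512, 32)
z_e = codebook[:128].reshape(16, 8, 32).clone()  # inputs that equal codes
z_q = quantize(z_e, codebook)
assert torch.allclose(z_q, z_e)  # nearest neighbors recover the inputs exactly
```

Because only the last axis is treated as the feature dimension, the same lookup handles `[B, F]`, `[B, T, F]`, or higher-rank inputs without special-casing.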
Does this solve it for your case?
ah ok... i actually never tested... my bad!
feel free to open the issue back up if you run into any problems.
hello! currently, the inputs to the layers assume an image, but it would be great if one could use it for a sequence of vectors (which is what you do internally, I suppose).
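For context, the image case described here can be bridged to a sequence-style quantizer by flattening the spatial axes into a token axis and restoring them afterwards. A minimal sketch with illustrative shapes:

```python
import torch

x = torch.randn(16, 32, 8, 8)  # [B, C, H, W] image feature map
B, C, H, W = x.shape

# move channels last, then merge H and W into one token axis
seq = x.permute(0, 2, 3, 1).reshape(B, H * W, C)  # [B, num_tokens, feat_size]

# ... quantize seq here ...

# invert the reshape/permute to recover the image layout
restored = seq.reshape(B, H, W, C).permute(0, 3, 1, 2)
assert torch.equal(restored, x)  # round trip is lossless
```

The permute before the reshape matters: it keeps each spatial position's channel vector contiguous as one token, so the quantizer sees `[batch_size x num_tokens x feat_size]` exactly as in the sequence example above.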