minyoungg / vqtorch

MIT License
101 stars, 9 forks

shapes of inputs #1

Closed: kashif closed this issue 1 year ago

kashif commented 1 year ago

Hello! Currently the layers assume that the input is an image, but it would be great if they could also be used on a sequence of vectors (which, I suppose, is what you do internally).
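For illustration, an image-like feature map of shape [batch_size x channels x height x width] can be treated as a sequence of channel vectors by moving the feature axis to the end and flattening the remaining axes. The following is a minimal plain-PyTorch sketch of that round trip, not vqtorch's internal code. The helper names `to_vectors` and `from_vectors` are hypothetical.

```python
from typing import Tuple

import torch


def to_vectors(x: torch.Tensor, dim: int = 1) -> Tuple[torch.Tensor, torch.Size]:
    """Move the feature axis `dim` to the end and flatten all other axes.

    Returns an [N x feat_size] matrix plus the permuted shape needed to undo it.
    """
    moved = x.movedim(dim, -1)
    return moved.reshape(-1, moved.shape[-1]), moved.shape


def from_vectors(v: torch.Tensor, moved_shape: torch.Size, dim: int = 1) -> torch.Tensor:
    """Inverse of `to_vectors`: restore the leading axes, then move features back to `dim`."""
    return v.reshape(moved_shape).movedim(-1, dim)


# Image-like input: [batch_size x channels x height x width]
x = torch.randn(2, 8, 4, 4)
vecs, shape = to_vectors(x, dim=1)
print(vecs.shape)  # torch.Size([32, 8]): 2*4*4 = 32 vectors of size 8
assert torch.equal(from_vectors(vecs, shape, dim=1), x)
```

With `dim=-1`, the same helpers cover the sequence case [batch_size x num_tokens x feat_size], where `movedim` is a no-op.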

minyoungg commented 1 year ago

It should work for arbitrary tensor shapes. Originally it supported tensors of up to 3 dimensions, intended for use with transformers. It should now work for the 2D case as well.

import torch
from vqtorch.nn import VectorQuant

# sequence modeling case: [batch_size x num_tokens x feat_size]
vq_layer = VectorQuant(feature_size=32, num_codes=512, dim=-1).cuda()
z_e = torch.randn(16, 8, 32).cuda()
vq_layer.codebook.weight.data[:128] = z_e.view(-1, 32).data.clone()
z_q, vq_dict = vq_layer(z_e)
assert torch.allclose(z_q, z_e)

# 2d matrix case: [batch_size x feat_size]
vq_layer = VectorQuant(feature_size=32, num_codes=512, dim=-1).cuda()
z_e = torch.randn(128, 32).cuda()
vq_layer.codebook.weight.data[:128] = z_e.data.clone()
z_q, vq_dict = vq_layer(z_e)
assert torch.allclose(z_q, z_e)

Does this solve it for your case?
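To show why quantizing along a chosen `dim` makes the layer shape-agnostic, here is a minimal plain-PyTorch sketch of nearest-neighbour vector quantization over an arbitrary axis. It mirrors the test above on CPU. It is an illustration under stated assumptions, not vqtorch's implementation. The function name `quantize_along_dim` is hypothetical. The straight-through gradient trick is an assumption borrowed from common VQ-VAE practice, and the sketch uses plain Euclidean distance with no normalization or affine reparameterization.

```python
import torch


def quantize_along_dim(z: torch.Tensor, codebook: torch.Tensor, dim: int = -1):
    """Nearest-neighbour vector quantization of `z` along axis `dim` (hypothetical helper).

    codebook: [num_codes x feat_size]. Returns (z_q, indices); z_q has z's shape,
    indices has z's shape with `dim` removed.
    """
    moved = z.movedim(dim, -1)                     # put the feature axis last
    flat = moved.reshape(-1, moved.shape[-1])      # [N x feat_size]
    dists = torch.cdist(flat, codebook)            # [N x num_codes] Euclidean distances
    idx = dists.argmin(dim=1)                      # index of the closest code per vector
    q = codebook[idx].reshape(moved.shape).movedim(-1, dim)
    z_q = z + (q - z).detach()                     # straight-through: forward = q, grad -> z
    return z_q, idx.reshape(moved.shape[:-1])


torch.manual_seed(0)
codebook = torch.randn(512, 32)

# sequence case: [batch_size x num_tokens x feat_size], quantized along the last axis
z_e = torch.randn(16, 8, 32)
codebook[:128] = z_e.reshape(-1, 32)
z_q, idx = quantize_along_dim(z_e, codebook, dim=-1)
assert torch.allclose(z_q, z_e)

# image case: [batch_size x channels x height x width], quantized along channels
z_img = torch.randn(2, 32, 8, 8)
codebook[:128] = z_img.movedim(1, -1).reshape(-1, 32)
z_q, idx = quantize_along_dim(z_img, codebook, dim=1)
assert torch.allclose(z_q, z_img)
```

Only the flattening step depends on the input's shape. The distance computation always sees an [N x feat_size] matrix, which is why the same layer can serve images, sequences, and plain 2D batches.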

kashif commented 1 year ago

Ah, OK... I actually never tested it. My bad!

minyoungg commented 1 year ago

Feel free to reopen the issue if you run into any problems.