lucidrains / vector-quantize-pytorch

Vector (and Scalar) Quantization, in Pytorch
MIT License

Correct `mask` usage #149

Closed: dwromero closed this issue 3 months ago

dwromero commented 4 months ago

Hi @lucidrains ,

I ran into a different problem. I am working with sequences that might have different lengths. Therefore, I pad them to the same length.

I want to avoid assigning codes to these padded positions and was hoping to use the `mask` parameter for this. Unfortunately, this leads to some unexpected behavior, which happens here:

 if exists(mask):
     mask = repeat(mask, 'b n -> c (b h n)', c = flatten.shape[0], h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1]))

In particular, this "nullifies" my mask.

For example, I am running 3-level VQ on a batched sequence of shape [4, 6120], and I pass a mask of the same shape. But then, per this line, the mask is changed to

ipdb> repeat(mask, 'b n -> c (b h n)', c = flatten.shape[0], h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1])).shape
torch.Size([1, 0])

It seems that flatten.shape[-2] is smaller than mask.shape[0] * mask.shape[1] = 4 * 6120 = 24480 here, so the integer division for h evaluates to 0 and the repeated mask collapses to an empty tensor. Is this the intended usage? Would you recommend another way of doing this? I can in principle mask the codes afterwards, but the codes at these positions would still be updated based on the commitment loss, which might lead to unexpected behavior.

I am looking forward to your answer. Once again, thank you for the amazing effort in building this awesome repo! :)

David

lucidrains commented 4 months ago

@dwromero oh that's strange, i don't see an issue locally

do you have a short snippet with the hyperparameters for reproducing this?

dwromero commented 4 months ago

I reproduced it in this notebook: https://colab.research.google.com/drive/1XMqJA7F-WSsWcHS-HCuQQqjuu63Ou8eh?usp=sharing
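Roughly, the setup looks like the sketch below (the hyperparameters here are placeholders rather than the exact ones in the notebook; I am assuming the 3-level VQ corresponds to `ResidualVQ` with `num_quantizers = 3`):

```python
import torch
from vector_quantize_pytorch import ResidualVQ

# placeholder hyperparameters -- only the 3 quantizer levels and the
# [batch = 4, seq_len = 6120] shape come from the setup described above
residual_vq = ResidualVQ(dim = 256, num_quantizers = 3, codebook_size = 1024)

x = torch.randn(4, 6120, 256)

# boolean mask: True for real tokens, False for padded positions
lengths = torch.tensor([6120, 5000, 3500, 2000])
mask = torch.arange(6120)[None, :] < lengths[:, None]

quantized, indices, commit_loss = residual_vq(x, mask = mask)
```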

If I understand correctly, the indices of the masked-out vectors should be -1, and the codes should be zero, since the mask basically tells the layer which parts should be ignored. Is this correct?

lucidrains commented 4 months ago

@dwromero hey David, would you like to see if the latest version addresses the issue? also threw in a test here

lucidrains commented 4 months ago

@dwromero i understand why you were expecting -1 for the indices, but i didn't find any other issue otherwise

dwromero commented 4 months ago

@lucidrains just checked and indeed the -1's are returned. I guess this means that these positions are not affecting the commit_loss or any other auxiliary loss anymore, right?

Perhaps it is my own misunderstanding, but shouldn't the returned codes be zero, as these correspond to "padded tokens"?

lucidrains commented 4 months ago

@dwromero i'm not quite sure what the right behavior is

i opted to just return the original input unquantized for those areas with False in the mask

you can always manually zero out by doing `mask[..., None] * output`
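e.g. something like this (a minimal sketch, assuming the usual `[b, n, d]` output and `[b, n]` boolean mask):

```python
quantized, indices, commit_loss = residual_vq(x, mask = mask)

# broadcast the [b, n] mask over the feature dimension so every
# padded position in the output becomes exact zeros
quantized = mask[..., None] * quantized
```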

dwromero commented 4 months ago

Yeah haha I understand. Me neither. I'll ask some colleagues about this :) Let's see what they think

lucidrains commented 4 months ago

@dwromero yea, as long as you zero out the input or output beforehand with that one-liner above, it will return 0s

dwromero commented 3 months ago

Hi @lucidrains , after asking around, people seem to think that the best solution is to return zeros at these positions. The reasoning is that if you return the original values, it can look as if the VQ reconstructions at those positions are perfect, which is not the case and can be confusing. Looking forward to hearing what you think.

lucidrains commented 3 months ago

@dwromero ok done!
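for reference, a quick way to check the behavior (a sketch reusing the setup from the notebook snippet above):

```python
quantized, indices, commit_loss = residual_vq(x, mask = mask)

# padded positions now come back as zero codes and -1 indices
assert torch.all(quantized[~mask] == 0)
assert torch.all(indices[~mask] == -1)
```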