lucidrains / vector-quantize-pytorch

Vector (and Scalar) Quantization, in Pytorch

Bug when using flag `orthogonal_reg_active_codes_only` #43

Closed trfnhle closed 1 year ago

trfnhle commented 1 year ago

Hello, thank you for the great work on VQ-VAE. While reading your implementation, I noticed that the code path enabled by the flag `orthogonal_reg_active_codes_only` does not seem to work properly: https://github.com/lucidrains/vector-quantize-pytorch/blob/b449efc35c0414d0338752d6a81b44142a8779af/vector_quantize_pytorch/vector_quantize_pytorch.py#L592-L603. After the line `codebook = self._codebook.embed`, the codebook has shape `[num_codebook, codebook_size, codebook_dim]`, so indexing its first dimension selects codebooks rather than codes. Therefore `codebook = codebook[unique_code_ids]` should be `codebook = codebook[:, unique_code_ids]`, and `num_codes = codebook.shape[0]` should be `num_codes = codebook.shape[1]`. Am I correct? Overall, the code above should be:

```python
if self.orthogonal_reg_active_codes_only:
    # only calculate orthogonal loss for the activated codes for this batch
    unique_code_ids = torch.unique(embed_ind)
    codebook = codebook[:, unique_code_ids]

num_codes = codebook.shape[1]
if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
    rand_ids = torch.randperm(num_codes, device = device)[:self.orthogonal_reg_max_codes]
    codebook = codebook[:, rand_ids]
```
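For context, the `codebook` tensor selected here feeds an orthogonal regularization term, which is why the `[num_codebook, codebook_size, codebook_dim]` layout matters. A minimal sketch of such a loss, assuming the usual squared-cosine-similarity penalty per codebook (the name `orthogonal_loss` and its normalization are illustrative, not the library's exact implementation):

```python
import torch
import torch.nn.functional as F

def orthogonal_loss(codebook: torch.Tensor) -> torch.Tensor:
    # codebook: [num_codebook, num_codes, codebook_dim]
    # penalizes squared cosine similarity between distinct codes,
    # pushing the selected codes toward mutual orthogonality
    normed = F.normalize(codebook, dim = -1)
    cos_sim = torch.einsum('h i d, h j d -> h i j', normed, normed)
    num_codes = codebook.shape[1]
    identity = torch.eye(num_codes, device = codebook.device)
    return ((cos_sim - identity) ** 2).sum() / (codebook.shape[0] * num_codes ** 2)
```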

And even with the above change, `orthogonal_reg_active_codes_only` only works properly when `num_codebook = 1`: `torch.unique(embed_ind)` pools code indices across all codebooks, so every codebook is sliced with the same union of ids, even though each codebook may activate a different set of codes. A per-codebook selection, as sketched below, would be needed for the multi-codebook case.
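One way the multi-codebook case could be handled is to select active codes per head (a sketch reusing `orthogonal_loss` from above; the assumption that `embed_ind` holds one index tensor per codebook is mine, not the library's actual layout):

```python
# each head selects only its own active codes; per-head losses are
# averaged since heads may activate different numbers of codes
losses = []
for head_codebook, head_ind in zip(codebook, embed_ind):
    active = torch.unique(head_ind)  # codes used by this head only
    losses.append(orthogonal_loss(head_codebook[active].unsqueeze(0)))
orthogonal_reg_loss = torch.stack(losses).mean()
```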

Looking forward to hearing your opinion.

lucidrains commented 1 year ago

@l4zyf9x hey Trinh! thank you for opening this issue! i neglected this feature when adding multi-headed VQ, and since so few researchers reported good results with it, i haven't bothered fixing it

all your points should be addressed in the latest commit. are you using this successfully in any model?
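For readers landing here, enabling the feature looks roughly like this (adapted from the project README's orthogonal regularization example; the argument values are illustrative):

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    orthogonal_reg_weight = 10,                # weight on the orthogonal loss term
    orthogonal_reg_active_codes_only = True,   # restrict the loss to codes active in the batch
    orthogonal_reg_max_codes = 128,            # cap on how many codes enter the loss
)

x = torch.randn(1, 1024, 256)
quantized, indices, loss = vq(x)  # loss folds in the weighted orthogonal term
```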

trfnhle commented 1 year ago

Thank you for fixing the issue. While experimenting with the orthogonal regularization loss, I do not see it affecting the reconstruction loss much, beyond its intended effect of pushing the codebook vectors apart from each other.
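A quick way to quantify that spreading effect (a hypothetical diagnostic, not part of the library):

```python
import torch
import torch.nn.functional as F

def mean_code_similarity(codebook: torch.Tensor) -> torch.Tensor:
    # codebook: [num_codes, dim]; mean absolute off-diagonal cosine
    # similarity - lower values mean the codes are more spread out
    normed = F.normalize(codebook, dim = -1)
    sim = normed @ normed.t()
    off_diag = ~torch.eye(sim.shape[0], dtype = torch.bool, device = sim.device)
    return sim[off_diag].abs().mean()
```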