lucidrains / vector-quantize-pytorch

Vector (and Scalar) Quantization, in Pytorch

Seeking clarifications regarding learnable codebook #147

Open anupsingh15 opened 4 months ago

anupsingh15 commented 4 months ago

Hi,

I am interested in learning codewords (without EMA updates) that are L2-normalized and mutually orthogonal, i.e. orthonormal. To do so, I created the vector quantizer with the following configuration:

from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    use_cosine_sim = True,
    orthogonal_reg_weight = 10,
    orthogonal_reg_max_codes = 128,
    orthogonal_reg_active_codes_only = False,
    learnable_codebook = True,
    ema_update = False
)
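
For context, I call it roughly as in the README's example (shapes in the comment):

import torch

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)  # (1, 1024, 256), (1, 1024), (1,)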

However, I noticed in the implementation at line 1071 that the loss contains only a single term, which pushes the input embeddings towards their corresponding quantized (codeword) embeddings. It does not include a second term that would push the codewords towards the inputs the other way round. Am I missing something here?
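
For reference, the two-term objective from the original VQ-VAE paper is what I have in mind; a minimal sketch in plain PyTorch (my own names, not this repo's code):

import torch.nn.functional as F

def vq_vae_loss(z_e, z_q, beta = 0.25):
    # codebook term: pulls the selected codewords towards the (detached) encoder outputs
    codebook_loss = F.mse_loss(z_q, z_e.detach())
    # commitment term: pulls the encoder outputs towards the (detached) codewords
    commitment_loss = F.mse_loss(z_e, z_q.detach())
    return codebook_loss + beta * commitment_loss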

Also, if I create a vector quantizer that learns the codebook via EMA, using the following configuration:

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    use_cosine_sim = True,
    orthogonal_reg_weight = 10,
    orthogonal_reg_max_codes = 128,
    orthogonal_reg_active_codes_only = False,
    learnable_codebook = False,
    ema_update = True,
    decay = 0.8
)

Will it still learn codewords that satisfy orthonormality?
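
My understanding of the EMA rule, sketched below (following the VQ-VAE paper's exponential moving average update; names and signature are mine, not this repo's exact code), is that the codewords are updated outside of backpropagation:

import torch

@torch.no_grad()
def ema_update(codebook, cluster_size, embed_sum, z_e, onehots, decay = 0.8, eps = 1e-5):
    # onehots: (num_inputs, codebook_size) one-hot code assignments for inputs z_e
    # exponentially smoothed counts and sums per code
    cluster_size.mul_(decay).add_(onehots.sum(dim = 0), alpha = 1 - decay)
    embed_sum.mul_(decay).add_(onehots.t() @ z_e, alpha = 1 - decay)
    # codewords become the smoothed cluster means -- no gradients involved
    codebook.copy_(embed_sum / (cluster_size.unsqueeze(1) + eps))

If that is right, then with learnable_codebook=False the orthogonal regularization loss has no gradient path into the codebook, so I would like to confirm whether orthonormality is still enforced in this setting.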