Closed MattMcPartlon closed 1 month ago
@MattMcPartlon hey Matt 😄 ! yes i think you are right, and digging into it, @Lijun-Yu made the change here a while back https://github.com/lucidrains/vector-quantize-pytorch/pull/89/files#diff-3488501716ef0c4a1b84127592be7e1015e0faf0d7d3b630b69d88e55a134701R213
i can prepare a PR and merge it once he gives the final word
He would certainly know :). Thanks, Phil!
@MattMcPartlon going to just merge it for now so you can go train your potentially life-saving model :wink:
@Lijun-Yu let me know if you see any issues with this revert
I think normalizing the original_input and the codebook or not gives the same result. In LFQ, the codebook is every combination of {-1, 1}, so it seems to be equivalent to the commit loss here. The codebook token whose signs match the original input will always have the largest similarity (i.e. the smallest distance).
@RobertLuo1 ah yes, you are right 🤦 reverted the revert 😆
that's really neat!
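import torch
xs = torch.randn(10,3)
ys = torch.randn(10,3)
# normalize the "codebook" so every row has the same norm
ys = ys / ys.norm(dim = -1, keepdim = True)
# nearest codeword by Euclidean distance
print(torch.argmin(torch.cdist(xs, ys),dim=-1))
# nearest codeword by negative dot product
print(torch.argmin(-torch.einsum("i d, c d -> ... i c", xs, ys),dim = -1))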
# output:
# tensor([1, 1, 1, 1, 6, 9, 8, 6, 5, 5])
# tensor([1, 1, 1, 1, 6, 9, 8, 6, 5, 5])
yup, great catch @RobertLuo1. thank you! (and thanks for the quick (yet unneeded) patch Phil!)
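For the LFQ case specifically, here is a minimal sketch of the same check in plain Python (no torch, to keep it dependency-free). The tiny 3-bit codebook, the sample vector `x`, and the helper names are all made up for illustration. It suggests that the nearest codeword in {-1, 1}^k is just the sign pattern of the input, whether or not the input is normalized first.

```python
import itertools
import math

def nearest_by_distance(x, codebook):
    """Index of the codeword with the smallest Euclidean distance to x."""
    return min(range(len(codebook)), key=lambda i: math.dist(x, codebook[i]))

def nearest_by_dot(x, codebook):
    """Index of the codeword with the largest dot product with x."""
    return max(
        range(len(codebook)),
        key=lambda i: sum(a * b for a, b in zip(x, codebook[i])),
    )

k = 3  # hypothetical number of bits
codebook = list(itertools.product([-1.0, 1.0], repeat=k))  # all 2**k sign patterns

x = [0.3, -2.0, 0.7]  # an arbitrary pre-quantized vector (no zeros, so no ties)
sign_code = tuple(1.0 if v > 0 else -1.0 for v in x)

norm = math.sqrt(sum(v * v for v in x))
x_unit = [v / norm for v in x]  # the same vector, normalized to unit length

idx_dist = nearest_by_distance(x, codebook)
idx_dot = nearest_by_dot(x, codebook)
idx_unit = nearest_by_distance(x_unit, codebook)

print(codebook[idx_dist], codebook[idx_dot], codebook[idx_unit], sign_code)
```

Every codeword has the same norm, so the squared distance only differs from the dot product by terms that do not depend on the codeword, which is why all three lookups agree.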
@MattMcPartlon huh, you are welcome! I have been trying it for a while, yet I still could not exploit its power. After many attempts I found that the avg_codebook_entropy loss term cannot be optimized to be larger; it seems to contradict the per_sample_entropy term. What about you? @MattMcPartlon
I have not had luck yet @RobertLuo1. The probabilities behind per_sample_entropy form an n_tokens (rows) x n_codewords (cols) matrix. The first term says that each row should have low entropy (each token commits to one codeword). The second term says that if you average over the rows then the entropy should be large (i.e. each codeword should have uniform probability of being sampled for some token). I agree that these can be somewhat opposing objectives, especially when the number of tokens is much smaller than the number of codewords...
Anyways, LFQ has not worked at all for me compared to the standard VAE objective. Have you tried any other quantizers?
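To make the two entropy terms concrete, here is a rough sketch in plain Python, assuming the probabilities are given as a list of rows (one row per token, one column per codeword). The function names are hypothetical and this is not the library's code. The toy input has 2 tokens and 4 codewords, each token fully committed to a different codeword.

```python
import math

def entropy(p):
    """Shannon entropy (in nats) of a probability vector, skipping zero entries."""
    return -sum(q * math.log(q) for q in p if q > 0)

def lfq_entropy_terms(probs):
    """Return (mean per-sample entropy, entropy of the averaged distribution)."""
    n_tokens = len(probs)
    n_codes = len(probs[0])
    # Term 1: entropy of each token's distribution over codewords, averaged.
    per_sample = sum(entropy(row) for row in probs) / n_tokens
    # Term 2: average the rows first, then take the entropy of that average.
    avg_probs = [sum(row[j] for row in probs) / n_tokens for j in range(n_codes)]
    codebook = entropy(avg_probs)
    return per_sample, codebook

# 2 tokens, 4 codewords, each token one-hot on a different codeword.
probs = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
]
per_sample, codebook = lfq_entropy_terms(probs)
max_codebook = math.log(4)  # entropy of a uniform distribution over 4 codewords

print(per_sample, codebook, max_codebook)
```

With confident (one-hot) tokens the averaged distribution can touch at most n_tokens codewords, so its entropy is capped at log(n_tokens), below the log(n_codewords) a uniform codebook would give. That is one way to see the tension when tokens are scarce relative to codewords.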
@MattMcPartlon Yeah! I see that in the magvit2 paper they only use 256 tokens in each batch. The flattened batch operation makes the value of the avg_codebook_entropy loss increase along with the batch size. LFQ does not seem to work for me either. I assume that maybe there is something missing in the implementation.
yup, I agree. I can start to overfit with a large enough codebook size (64k), but can't find a setting that generalizes well. It's a shame because I really liked the magvit2 paper 😓.
I'm planning to try out the other quantizers tomorrow. I'll post back if I have any luck :).
@RobertLuo1 This is not a very satisfying fix, but I think I've managed to get LFQ working well by cranking down the commitment + entropy loss weights (0.001) and adding another loss term:
aux_loss=torch.mean(
((original_input**2 - torch.ones_like(original_input)) ** 2)[mask]
),
This should already be enforced by the commitment loss, but it appears to be useful for convergence (still very early in training, but it's the first run I've seen that behaves normally).
The yellow and light blue runs have the same configuration as two of the others, except for the addition of that aux_loss term. The distogram_loss is my proxy for reconstruction loss.
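As a quick sanity check on what that extra term rewards, here is a plain-Python mirror of the expression, assuming flat lists for the values and the mask (the real code operates on torch tensors, and the name `aux_loss` here is just a stand-in). The term is zero exactly when every unmasked entry sits at -1 or 1 and grows as entries drift toward 0 or beyond 1 in magnitude.

```python
def aux_loss(values, mask):
    """Mean of (v**2 - 1)**2 over the entries where mask is True."""
    kept = [v for v, m in zip(values, mask) if m]
    return sum((v * v - 1.0) ** 2 for v in kept) / len(kept)

# Entries already at +/-1 incur no penalty.
at_corners = aux_loss([1.0, -1.0, 1.0], [True, True, True])

# An entry at 0 costs (0 - 1)**2 = 1, an entry at 2 costs (4 - 1)**2 = 9.
off_corners = aux_loss([0.0, 2.0], [True, True])

# Masked-out entries are ignored entirely.
masked = aux_loss([0.0, 2.0], [True, False])

print(at_corners, off_corners, masked)
```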
another thing I wanted to try is to make the projection before the scalar quantization use cosine similarity with some learned gamma
@lucidrains very nice idea. I think some way of forcing the pre-quantized input to sit near [-1,1]^{k} is important. If you don't have this "normalization" then the straight-through gradient can diverge significantly from the "true unquantized gradient" (eq. 14). With this extra term you ensure that the pre-quantized representation is close to the quantized representation, so the decoder gradients should be more informative to the encoder... I could just be talking out my @$$ though 😅 . In theory the commitment loss should do this, but I had no luck.
Also, FSQ works out of the box 🤷
yea FSQ worked for me on a test drive too! I'd like to figure out LFQ at some point because I think it is perfect for employing genetic algorithms on some latent space
I suppose genetic algorithms should work for FSQ too, just have to think about it
@MattMcPartlon Thanks for sharing. Actually, my loss behaves normally except for the avg entropy term. The commitment loss and the per sample entropy can be optimized and decrease normally. However, I could not figure out how to optimize the avg entropy by simply following the magvit2 paper. Would decreasing the weight of the avg entropy term help? I find that this term is much larger than any other loss term in my configuration.
And by the way, I think the aux loss term behaves similarly to the commitment loss. I have no idea how to train a tokenizer with such low FID.
@lucidrains I think FSQ can be optimized more easily than LFQ since it has no extra loss to optimize. However, I think its performance will be worse than LFQ's.
@RobertLuo1 Ya, I agree it should do the same thing as commitment loss, but even when I set the commitment loss weight to 1.0, I still can't optimize it well. In the plot below, the two runs that converge have auxiliary loss + commitment applied. the others just have standard commitment loss.
Regarding FSQ and LFQ, naively it seems like FSQ is kind of a generalization of LFQ (FSQ with levels = [2,2,...,2] should have the exact same capacity as LFQ, right?).
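A tiny sketch of the capacity comparison, assuming the usual counting: an FSQ codebook has one entry per combination of levels (the product of the levels), and an LFQ codebook has one entry per sign pattern (2 to the number of dimensions). The helper names are made up for illustration.

```python
import math

def fsq_codebook_size(levels):
    """Number of distinct FSQ codes: one per combination of per-dimension levels."""
    return math.prod(levels)

def lfq_codebook_size(dim):
    """Number of distinct LFQ codes: one per sign pattern in {-1, 1}**dim."""
    return 2 ** dim

dim = 16
fsq_size = fsq_codebook_size([2] * dim)
lfq_size = lfq_codebook_size(dim)

print(fsq_size, lfq_size)
```

With 16 binary dimensions both come out to the 64k codebook size mentioned earlier in the thread; the difference between the two is in the losses and the projection, not in the number of codes.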
Yea! If the levels are [2,2,2,...], FSQ is the same as LFQ but without the aux loss. However, I think that if you simply use FSQ with levels [2,2,2,...], the performance will drop heavily since it has no codebook loss as a constraint. I think LFQ can force the encoder and decoder to learn a strong mapping if trained properly. That's the reason why it has such low FID and LPIPS. By the way, have you ever trained on ImageNet and tested its FID?
Hi there, sorry to intrude, but I think one possibility is that, as it is implemented and used here, the entropy is concave and may be negative for values larger than one (the prob vector is a collection of distances, and there is no evidence that they will be less than one). So in the end you want to find the minimum of a concave, possibly negative function... here goes trouble :)
One way to counterbalance that would be to replace the sum() by a mean(), as a way to use the possibly large dimension as a weight. This will also make the difference between the "average of entropy" and the "entropy of the average" more consistent.
Hope this is clear...
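For reference, a small sketch of the sign behaviour being described, looking only at a single entropy term -p * log(p) in isolation. Whether the values fed into the entropy can actually exceed one depends on the implementation and is not checked here; the function name is hypothetical.

```python
import math

def entropy_term(p):
    """One summand of the entropy, -p * log(p), for a strictly positive p."""
    return -p * math.log(p)

inside = entropy_term(0.5)   # p in (0, 1): the term is positive
at_one = entropy_term(1.0)   # p == 1: the term is zero
outside = entropy_term(2.0)  # p > 1: the term turns negative

print(inside, at_one, outside)
```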
In my experiment, the avg codebook entropy is about 5 while the per sample entropy is about 0.0x. Do you mean that the negative loss being larger than one will cause trouble? Moreover, is the sum() -> mean() replacement only applied in the avg entropy calculation (https://github.com/lucidrains/vector-quantize-pytorch/blob/4a643eb4161fdd154d6d04f11c798fde714f4cec/vector_quantize_pytorch/lookup_free_quantization.py#L49) and not in the per sample one?
There are three terms in the total loss:
The entropy per sample aims at being zero (non-uniform probability distribution), while the (opposite of the) averaged entropy aims at being negatively large. But then the sum of the two will be negative and will compete badly with the reconstruction loss: the reconstruction loss can be as large as the averaged entropy, but the difference will be close to zero... which is ultimately the goal of any ML training.
The function https://github.com/lucidrains/vector-quantize-pytorch/blob/4a643eb4161fdd154d6d04f11c798fde714f4cec/vector_quantize_pytorch/lookup_free_quantization.py#L48-L49 should be with mean instead of sum, in all cases. At least to my understanding.
Thanks for your reply. I will try that in my experiment later. However, I saw that MAGVIT still uses the sum function instead of the mean, so I think maybe the negative value won't matter a lot: https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L303. But that is used in the standard VQGAN case.
I think the sum is fine. Either way the values differ by a constant (since the channel dimension is fixed to L=log(codebook_size), using a mean should be equivalent to dividing your loss weights by L). Ultimately the objective will look the same to the model.
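A minimal numeric check of that equivalence, assuming the reduction runs over a fixed number L of per-dimension terms. The term values and the weight below are arbitrary numbers chosen for illustration.

```python
def weighted_sum_loss(terms, weight):
    """Loss weight applied to the sum over the L terms."""
    return weight * sum(terms)

def weighted_mean_loss(terms, weight):
    """Loss weight applied to the mean over the L terms."""
    return weight * sum(terms) / len(terms)

terms = [0.2, 0.7, 0.1, 0.4]  # arbitrary per-dimension values, L = 4
weight = 0.5
L = len(terms)

with_mean = weighted_mean_loss(terms, weight)
with_rescaled_sum = weighted_sum_loss(terms, weight / L)

print(with_mean, with_rescaled_sum)
```

Since L is fixed for a given codebook size, switching between the two reductions only rescales the loss weight; it does not change where the optimum is.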
The fact that the per-sample entropy is basically zero suggests that this softmax is saturated to a single point. https://github.com/lucidrains/vector-quantize-pytorch/blob/4a643eb4161fdd154d6d04f11c798fde714f4cec/vector_quantize_pytorch/lookup_free_quantization.py#L224
Maybe using a less aggressive value for the inverse temperature (inv_temperature) could help? The distance is just 2 * negative cosine similarity, which takes values in [L, L-2, ..., 2-L, -L] when the commitment loss is 0. By default, we scale this by 100, which means that the probabilities assigned to e.g. [1,1,1,1,1,...] and [-1,1,1,1,...] differ by a factor of exp[200]. In general, the ratio of mass assigned to x vs. y is `exp(2 * hamming_distance(x, y) * inv_temperature)`. If the model wants to distribute these probabilities more evenly, it would be forced to increase the commitment loss (i.e. the pre-quantization values must move closer to [0,0,0,0...0] to place mass on neighboring codewords).
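Here is a rough sketch of that ratio in plain Python. It assumes the softmax logit for a codeword is inv_temperature times its dot product with the pre-quantized input, which is what the stated ratio implies; if the implementation carries an extra constant factor, the exponent scales accordingly. All names are hypothetical. Working with the log of the ratio avoids overflow.

```python
def dot(x, c):
    return sum(a * b for a, b in zip(x, c))

def hamming_distance(a, b):
    """Number of positions where two codewords differ."""
    return sum(1 for u, v in zip(a, b) if u != v)

def log_prob_ratio(x, c1, c2, inv_temperature):
    """log(p(c1) / p(c2)) under a softmax over logits inv_temperature * <x, c>.

    The softmax normalizer cancels in the ratio, so only the logit gap remains.
    """
    return inv_temperature * (dot(x, c1) - dot(x, c2))

c1 = [1.0, 1.0, 1.0, 1.0]
c2 = [-1.0, 1.0, 1.0, 1.0]  # differs from c1 in one bit

# Input sitting exactly on c1 (commitment loss 0), default scale of 100.
saturated = log_prob_ratio(c1, c1, c2, inv_temperature=100.0)

# Same input with a gentler inverse temperature.
gentle = log_prob_ratio(c1, c1, c2, inv_temperature=1.0)

# Input shrunk toward the origin: the gap closes, but commitment loss would rise.
shrunk = log_prob_ratio([0.1] * 4, c1, c2, inv_temperature=100.0)

print(saturated, gentle, shrunk)
```

The first case reproduces the exp[200] gap for a single flipped bit; the other two show the only two ways to close it under this assumption, lowering inv_temperature or pulling the input away from the corners.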
ok, you all got me interested with this discussion
running experiments today til tomorrow to see whether LFQ can really beat FSQ with all level 2s, at least, on the audio dataset i'm currently looking at
yea, FSQ all level 2s work for me without any auxiliary losses, though it lags behind having > 2 levels. this is in a hierarchical vae though with small codebook size (256), so perhaps a bit different
should get some LFQ runs by mid afternoon
oh ha, LFQ default settings already diverged after 200 steps
edit: nvm, was just the default commit loss
edit2: @MattMcPartlon your makeshift loss def works better for the commitment loss, at least early in training. recon loss lags behind FSQ with all level 2s though 🤷 . i'll let it run until 20k steps, let it have a fair shot
yea i'm going to stop it early, it doesn't look that great
looking a bit better with soft clamping to -10 to 10, but still behind FSQ with all level 4s. will let it run for a bit longer this time
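The thread does not say which soft clamp was used; one common form, assumed here purely for illustration, squashes values through tanh so they approach but never exceed the limit, while staying close to the identity near zero.

```python
import math

def soft_clamp(x, limit=10.0):
    """Smoothly squash x into (-limit, limit); roughly the identity for |x| << limit."""
    return math.tanh(x / limit) * limit

small = soft_clamp(0.5)      # barely changed
large = soft_clamp(1000.0)   # saturates at the limit
negative = soft_clamp(-25.0)  # symmetric for negative inputs

print(small, large, negative)
```

Unlike a hard clamp, this keeps a nonzero gradient everywhere, which may matter when the clamped values feed a straight-through quantizer.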
cosine sim projection also works, but still unable to beat FSQ level-4s
however, both are already way more stable than VQ in a hierarchical VAE setting
following this issue
https://github.com/lucidrains/vector-quantize-pytorch/blob/e244d472054fff7569bf32f80c2fca22928e9c16/vector_quantize_pytorch/lookup_free_quantization.py#L222
Hey Phil, Thanks for another great implementation 😄 . Regarding the distance calculation in LFQ (linked) I think this only holds if you're comparing self.codebook to the quantization of x (i.e. both have constant norm).