lucidrains / vector-quantize-pytorch

Vector (and Scalar) Quantization, in Pytorch

Possible bug in LFQ distance calculation #120

Closed MattMcPartlon closed 1 month ago

MattMcPartlon commented 1 month ago

https://github.com/lucidrains/vector-quantize-pytorch/blob/e244d472054fff7569bf32f80c2fca22928e9c16/vector_quantize_pytorch/lookup_free_quantization.py#L222

Hey Phil, thanks for another great implementation 😄. Regarding the distance calculation in LFQ (linked), I think this only holds if you're comparing self.codebook to the quantization of x (i.e. when both have constant norm).

# without normalization: the two argmins disagree
import torch
xs = torch.randn(10,3)
ys = torch.randn(10,3)
#xs,ys = map(lambda x: x/torch.norm(x,dim=-1,keepdim=True), (xs,ys))
print(torch.argmin(torch.cdist(xs,ys),dim=-1))
print(torch.argmin(-torch.einsum("i d, c d -> ... i c", xs, ys),dim=-1))

# output:
#tensor([8, 4, 7, 2, 3, 0, 4, 7, 1, 4])
#tensor([3, 9, 3, 3, 3, 2, 5, 9, 9, 5])
# with both normalized to unit norm: the argmins agree
import torch
xs = torch.randn(10,3)
ys = torch.randn(10,3)
xs,ys = map(lambda x: x/torch.norm(x,dim=-1,keepdim=True), (xs,ys))
print(torch.argmin(torch.cdist(xs,ys),dim=-1))
print(torch.argmin(-torch.einsum("i d, c d -> ... i c", xs, ys),dim=-1))

# output:
#tensor([4, 7, 7, 7, 4, 7, 7, 8, 8, 8])
#tensor([4, 7, 7, 7, 4, 7, 7, 8, 8, 8])
lucidrains commented 1 month ago

@MattMcPartlon hey Matt 😄 ! yes i think you are right, and digging into it, @Lijun-Yu made the change here a while back https://github.com/lucidrains/vector-quantize-pytorch/pull/89/files#diff-3488501716ef0c4a1b84127592be7e1015e0faf0d7d3b630b69d88e55a134701R213

i can prepare a PR and merge it once he gives the final word

MattMcPartlon commented 1 month ago

He would certainly know :). Thanks, Phil!

lucidrains commented 1 month ago

@MattMcPartlon going to just merge it for now so you can go train your potentially life-saving model :wink:

@Lijun-Yu let me know if you see any issues with this revert

RobertLuo1 commented 1 month ago

I think it performs the same whether or not you normalize the original_input and the codebook. In LFQ, each codeword is a combination of {-1, 1}, so all codewords have the same norm; therefore it seems to be equivalent for the commit loss here. The codeword whose signs match the original input will always have the largest inner product (i.e. the smallest distance).
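A quick way to check this with sign vectors in {-1, 1}^d standing in for the LFQ codebook (the xs/codes names here are just for illustration):

import torch

xs = torch.randn(10, 6)                    # unnormalized pre-quantized inputs
codes = torch.sign(torch.randn(16, 6))     # {-1, 1}-valued codewords, all with the same norm

# with a constant-norm codebook, smallest distance and largest dot product pick the same codeword
print(torch.argmin(torch.cdist(xs, codes), dim=-1))
print(torch.argmin(-torch.einsum("i d, c d -> i c", xs, codes), dim=-1))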

lucidrains commented 1 month ago

@RobertLuo1 ah yes, you are right 🤦 reverted the revert 😆

that's really neat!

import torch

xs = torch.randn(10,3)
ys = torch.randn(10,3)

ys = ys / ys.norm(dim = -1, keepdim = True)   # only the codebook needs constant norm

print(torch.argmin(torch.cdist(xs, ys),dim=-1))
print(torch.argmin(-torch.einsum("i d, c d -> ... i c", xs, ys),dim = -1))

# output:
# tensor([1, 1, 1, 1, 6, 9, 8, 6, 5, 5])
# tensor([1, 1, 1, 1, 6, 9, 8, 6, 5, 5])
MattMcPartlon commented 1 month ago

yup, great catch @RobertLuo1. thank you! (and thanks for the quick (yet unneeded) patch Phil!)

RobertLuo1 commented 1 month ago

@MattMcPartlon You are welcome! I have been trying it for a while, yet I still could not exploit its power: across many attempts, I found that the avg_codebook_entropy loss term cannot be optimized to be larger. It seems to conflict with the per_sample_entropy term. What about you? @MattMcPartlon

MattMcPartlon commented 1 month ago

I have not had luck yet @RobertLuo1. The probabilities behind per_sample_entropy form an n_tokens (rows) x n_codewords (cols) matrix. The first term says that each row should have low entropy (each token should commit to a single codeword). The second term says that if you average over the rows, the entropy of the averaged distribution should be large (i.e. each codeword should have a roughly uniform probability of being selected by some token). I agree that these can be somewhat opposing objectives, especially when the number of tokens is much smaller than the number of codewords...
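For concreteness, a minimal sketch of how I read those two terms, assuming a probability matrix probs of shape (n_tokens, n_codewords) obtained by a softmax over the negative distances (variable names here are illustrative, not the repo's API):

import torch

def entropy(p, eps = 1e-9):
    # entropy of distributions along the last dimension
    return -(p * (p + eps).log()).sum(dim = -1)

n_tokens, n_codewords = 8, 16
logits = torch.randn(n_tokens, n_codewords)       # e.g. -distance * inv_temperature
probs = logits.softmax(dim = -1)                  # each row: one token's distribution over codewords

per_sample_entropy = entropy(probs).mean()        # want LOW: each token commits to one codeword
avg_probs = probs.mean(dim = 0)                   # average codeword usage across tokens
codebook_entropy = entropy(avg_probs)             # want HIGH: codewords used uniformly overall

# combined entropy objective, MAGVIT-v2 style: minimize the first, maximize the second
entropy_aux_loss = per_sample_entropy - codebook_entropy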

Anyways, LFQ has not worked at all for me compared to the standard VAE objective. Have you tried any other quantizers?

RobertLuo1 commented 1 month ago

@MattMcPartlon Yeah! I see that in the magvit2 paper they only use 256 tokens per batch. Flattening over the batch makes the avg_codebook_entropy loss increase with the batch size. LFQ does not seem to work for me either; I assume something may be missing in the implementation.

MattMcPartlon commented 1 month ago

yup, I agree. I can start to overfit with a large enough codebook size (64k), but can't find a setting that generalizes well. It's a shame because I really liked the magvit2 paper 😓.

I'm planning to try out the other quantizers tomorrow. I'll post back if I have any luck :).

MattMcPartlon commented 1 month ago

@RobertLuo1 This is not a very satisfying fix, but I think I've managed to get LFQ working well by cranking down the commitment + entropy loss weights (0.001) and adding another loss term:

aux_loss = torch.mean(
    ((original_input ** 2 - torch.ones_like(original_input)) ** 2)[mask]
)

this should already be enforced by the commitment loss, but it appears to be useful for convergence (still very early in training, but it's the first run I've seen that behaves normally).

The yellow and light blue runs have the same configuration as two of the others, except for the addition of that aux_loss term. The distogram_loss is my proxy for reconstruction loss.

Screenshot 2024-05-05 at 1 34 29 PM
lucidrains commented 1 month ago

another thing I wanted to try is to make the projection before the scalar quantization use cosine similarity with some learned gamma
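for reference, a minimal sketch of what such a projection could look like (this is an assumption about the idea, not the repo's implementation):

import torch
from torch import nn
import torch.nn.functional as F

class CosineSimProject(nn.Module):
    # hypothetical: project, l2-normalize, then rescale with a learned per-channel gamma
    def __init__(self, dim, codebook_dim):
        super().__init__()
        self.proj = nn.Linear(dim, codebook_dim)
        self.gamma = nn.Parameter(torch.ones(codebook_dim))

    def forward(self, x):
        x = self.proj(x)
        x = F.normalize(x, dim = -1)   # unit norm keeps the pre-quantized values bounded
        return x * self.gamma          # learned scale applied before the sign quantization

# usage (hypothetical dims): proj = CosineSimProject(dim = 512, codebook_dim = 16)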

MattMcPartlon commented 1 month ago

@lucidrains very nice idea. I think some way of forcing the pre-quantized input to sit near {-1,1}^k is important. If you don't have this "normalization" then the straight-through gradient can diverge significantly from the "true" unquantized gradient (eq. 14). With this extra term you ensure that the pre-quantized representation is close to the quantized representation, so the decoder gradients should be more informative to the encoder... I could just be talking out my @$$ though 😅 . In theory the commitment loss should do this, but I had no luck.
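To illustrate the straight-through point (a standard STE sketch, not code from the repo):

import torch

x = torch.randn(4, 8, requires_grad = True)   # pre-quantized encoder output
quantized = torch.sign(x)                     # LFQ-style quantization to {-1, 1}

# straight-through estimator: the forward pass uses the quantized values,
# the backward pass treats quantization as the identity
ste_out = x + (quantized - x).detach()

# the closer x already sits to {-1, 1}^k, the smaller this gap,
# and the more faithful the straight-through gradient is
gap = (quantized - x).abs().mean()
print(gap)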

MattMcPartlon commented 1 month ago

Also, FSQ works out of the box 🤷

lucidrains commented 1 month ago

yea FSQ worked for me on a test drive too! I'd like to figure out LFQ at some point because I think it is perfect for employing genetic algorithms on some latent space

lucidrains commented 1 month ago

I suppose genetic algorithms should work for FSQ too, just have to think about it

RobertLuo1 commented 1 month ago

@MattMcPartlon Thanks for sharing. Actually, my losses behave normally except for the avg entropy term. The commitment loss and per-sample entropy can be optimized and decrease normally. However, I could not figure out how to optimize the avg entropy by simply following the magvit2 paper. Would decreasing the weight of the avg entropy help? I find that this term is much larger than any other loss term in my configuration.

RobertLuo1 commented 1 month ago

And by the way, I think the aux loss term performs similarly to the commitment loss. I have no idea how to train a tokenizer with such a low FID.

RobertLuo1 commented 1 month ago

@lucidrains I think FSQ can be optimized more easily than LFQ since it does not have any auxiliary loss to optimize. However, I think the performance will be worse than LFQ's.

MattMcPartlon commented 1 month ago

@RobertLuo1 Ya, I agree it should do the same thing as the commitment loss, but even when I set the commitment loss weight to 1.0, I still can't optimize it well. In the plot below, the two runs that converge have the auxiliary loss + commitment loss applied; the others just have the standard commitment loss.

Regarding FSQ and LFQ, naively it seems like FSQ is kind of a generalization of LFQ (FSQ with levels = [2,2,...,2] should have the exact same capacity as LFQ, right?).

Screenshot 2024-05-05 at 8 14 14 PM
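Regarding the capacity point, a quick sanity check of the counting (plain arithmetic, no library calls):

import math

levels = [2] * 10                        # FSQ with ten level-2 channels
fsq_codebook_size = math.prod(levels)    # 2**10 = 1024 implicit codewords

lfq_dim = 10                             # LFQ with a 10-dimensional {-1, 1} code
lfq_codebook_size = 2 ** lfq_dim         # also 1024 codewords

assert fsq_codebook_size == lfq_codebook_size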
RobertLuo1 commented 1 month ago

Yeah! If the levels are [2,2,2,...], FSQ is the same as LFQ but without the aux losses. However, I think if you simply use FSQ with levels [2,2,2,...], the performance will drop heavily since there is no codebook loss to constrain it. I think LFQ can force the encoder and decoder to learn a strong mapping if trained properly; that's the reason it achieves such low FID and LPIPS. By the way, have you ever trained on ImageNet and tested the FID?

MisterBourbaki commented 1 month ago

Hi there, sorry to intrude, but I think one possibility is that, as it is implemented and used here, the entropy is concave and may be negative for values larger than one (the prob vector is a collection of distances, and there is no evidence that they will be less than one). So in the end you want to find the minimum of a concave, possibly negative function... and that is where the trouble comes from :)

One way to counterbalance that would be to replace the sum() by a mean(), as a way to use the possibly large dimension as a weight. This will also make the difference between the "average of entropy" and the "entropy of the average" more consistent.

Hope this is clear...

RobertLuo1 commented 1 month ago

In my experiment, the avg codebook entropy is about 5 while the per-sample entropy is about 0.0x. Do you mean that the loss going negative for values larger than one is what causes trouble? Also, should the sum() -> mean() replacement apply only in the avg entropy calculation (https://github.com/lucidrains/vector-quantize-pytorch/blob/4a643eb4161fdd154d6d04f11c798fde714f4cec/vector_quantize_pytorch/lookup_free_quantization.py#L49) and not in the per-sample one?

MisterBourbaki commented 1 month ago

There are three terms in the total loss:

The per-sample entropy aims at being zero (a non-uniform probability distribution), while the (opposite of the) averaged entropy aims at being negatively large. But then the sum of the two will be negative and will badly compete with the reconstruction loss: the reconstruction loss can be as large as the averaged entropy, yet the difference between them will be close to zero... which is ultimately what any ML training drives toward.

The function https://github.com/lucidrains/vector-quantize-pytorch/blob/4a643eb4161fdd154d6d04f11c798fde714f4cec/vector_quantize_pytorch/lookup_free_quantization.py#L48-L49 should use a mean instead of a sum, in all cases. At least to my understanding.

RobertLuo1 commented 1 month ago

Thanks for your reply. I will try that in my experiments later. However, I see that MAGVIT still uses the sum instead of the mean, so maybe the negative value won't matter a lot: https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L303. But it is used in the standard VQGAN cases.

MattMcPartlon commented 1 month ago

I think the sum is fine. Either way the values differ by a constant factor (since the channel dimension is fixed to L = log2(codebook_size), using a mean should be equivalent to dividing your loss weights by L). Ultimately the objective will look the same to the model.

The fact that the per-sample entropy is basically zero suggests that this softmax is saturated to a single point. https://github.com/lucidrains/vector-quantize-pytorch/blob/4a643eb4161fdd154d6d04f11c798fde714f4cec/vector_quantize_pytorch/lookup_free_quantization.py#L224

Maybe using a less aggressive value for the inverse temperature (inv_temperature) could help? The distance is just the negative cosine similarity (up to scale, the negative dot product), which takes values in {L, L-2, ..., 2-L, -L} when the commitment loss is 0. By default, we scale this by 100, which means that the probabilities assigned to e.g. [1,1,1,1,1,...] and [-1,1,1,1,...] differ by a factor of exp(200). In general, the ratio of mass assigned to x vs. y is `exp(2 * hamming_distance(x,y) * inv_temperature)`. If the model wants to distribute these probabilities more evenly, it is forced to increase the commitment loss (i.e. the pre-quantization values must move closer to [0,0,0,...,0] to place mass on neighboring codewords).
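A tiny numerical check of that saturation claim (L and the codewords below are made up for illustration):

import torch

L = 10                                   # code dimension, log2(codebook_size)
inv_temperature = 100.0                  # the default scaling discussed above

x = torch.ones(L)                        # pre-quantized values sitting exactly on a codeword
codes = torch.ones(2, L)
codes[1, 0] = -1                         # a neighboring codeword at Hamming distance 1

dist = -(x @ codes.t())                  # negative dot products: tensor([-10., -8.])
probs = torch.softmax(-dist * inv_temperature, dim=-1)

# the probability ratio is exp(2 * hamming * inv_temperature) = exp(200),
# so the softmax is completely saturated on the matching codeword
print(probs)                             # ~tensor([1., 0.])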

lucidrains commented 1 month ago

ok, you all got me interested with this discussion

running experiments today til tomorrow to see whether LFQ can really beat FSQ with all level 2s, at least, on the audio dataset i'm currently looking at

lucidrains commented 1 month ago

[image: 8pafkf]

lucidrains commented 1 month ago
Screen Shot 2024-05-07 at 10 49 07 AM

yea, FSQ all level 2s work for me without any auxiliary losses, though it lags behind having > 2 levels. this is in a hierarchical vae though with small codebook size (256), so perhaps a bit different

should get some LFQ runs by mid afternoon

lucidrains commented 1 month ago

oh ha, LFQ default settings already diverged after 200 steps

edit: nvm, was just the default commit loss

edit2: @MattMcPartlon your makeshift loss def works better for the commitment loss, at least early in training. recon loss lags behind FSQ with all level 2s though 🤷 . i'll let it run until 20k steps, let it have a fair shot

lucidrains commented 1 month ago
Screen Shot 2024-05-07 at 1 01 19 PM

yea i'm going to stop it early, it doesn't look that great

lucidrains commented 1 month ago

Screenshot from 2024-05-07 17-06-05

looking a bit better with soft clamping to -10 to 10, but still behind FSQ with all level 4s. will let it run for a bit longer this time
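for reference, one simple form of soft clamping (a tanh-based sketch; an assumption, not necessarily the exact function used here):

import torch

def soft_clamp(x, value = 10.0):
    # smoothly squashes x into (-value, value) while staying roughly identity near zero
    return (x / value).tanh() * value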

lucidrains commented 1 month ago

Screenshot from 2024-05-08 06-48-17

cosine sim projection also works, but still unable to beat FSQ level-4s

however, both are already way more stable than VQ in a hierarchical VAE setting

JohnHerry commented 1 week ago

following this issue