lucidrains / vector-quantize-pytorch

Vector (and Scalar) Quantization, in Pytorch
MIT License

Implementation of LFQ #78

Closed Lijun-Yu closed 10 months ago

Lijun-Yu commented 10 months ago

Hi,

Thank you for the quick integration of LFQ while we work on the official release! I'd like to clarify two points about the implementation details to avoid confusing users:

(1) By default we do not use an activation function - https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/lookup_free_quantization.py#L57

(2) We use the commitment loss as in VQVAE - https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/vector_quantize_pytorch.py#L987

These do not change anything at inference time, given the simple nature of LFQ, but they do affect training performance. For the binary case, an activation function isn't actually needed, as shown by the formulas in our paper. These are also key differences between LFQ and FSQ - https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/lookup_free_quantization.py#L5

Thank you!
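To make points (1) and (2) concrete, here is a minimal PyTorch sketch of how they might look; the function and variable names (lfq_quantize, commit_weight) and the straight-through trick are illustrative assumptions, not the exact code in this repo.

```python
import torch
import torch.nn.functional as F

def lfq_quantize(z: torch.Tensor, commit_weight: float = 0.25):
    """Minimal sketch of points (1) and (2): no activation on z, and only the
    VQ-VAE commitment loss (the codebook is implicit, so there is no codebook loss).

    z: (..., num_bits) raw projection output.
    """
    # (1) no tanh/sigmoid: quantize the raw values by sign to {-1, +1}
    quantized = torch.where(z > 0, 1.0, -1.0)

    # straight-through estimator so gradients still reach z
    quantized_st = z + (quantized - z).detach()

    # (2) commitment loss as in VQ-VAE; no q_latent_loss, since there is no learned codebook
    commit_loss = commit_weight * F.mse_loss(z, quantized.detach())

    return quantized_st, commit_loss
```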

Lijun-Yu commented 10 months ago

A third point:

(3) We use global or sub-groups to estimate the codebook entropy, which is different from the bitwise entropy - https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/lookup_free_quantization.py#L183

Lijun-Yu commented 10 months ago

Here is a reference that hopefully helps: https://github.com/google-research/magvit/blob/main/videogvt/models/vqvae.py#L44

Compared to this VectorQuantizer class, in LFQ quantized becomes jnp.where(x > 0, 1, -1) and there is no q_latent_loss. The codebook entropy term stays the same as VQ for codebooks <= 2^18, and is factorized into multiple subgroups for larger ones.
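For the sub-group factorization, a rough sketch of one reading of the description above (not the official code; grouped_code_logits and bits_per_group are hypothetical names):

```python
import torch

def grouped_code_logits(z: torch.Tensor, bits_per_group: int) -> torch.Tensor:
    """Split the bit dimension into sub-groups and score each sample against the
    implicit {-1, +1}^bits_per_group codebook of every group.

    z: (batch, num_bits) pre-quantization values.
    Returns logits of shape (batch, num_groups, 2 ** bits_per_group); the usual VQ
    entropy terms can then be computed per group instead of over all 2^num_bits codes.
    """
    batch, num_bits = z.shape
    assert num_bits % bits_per_group == 0, 'num_bits must split into equal groups'

    groups = z.view(batch, -1, bits_per_group)  # (batch, num_groups, bits_per_group)

    # all 2^bits_per_group sign patterns of one group, shape (2^bits_per_group, bits_per_group)
    codes = torch.cartesian_prod(*([torch.tensor([-1.0, 1.0])] * bits_per_group))
    codes = codes.reshape(-1, bits_per_group).to(z)

    # dot-product similarity of each sample/group against every code in that group
    return torch.einsum('b n d, c d -> b n c', groups, codes)
```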

jeasinema commented 10 months ago

not sure if this is relevant, but when I run LFQ with training=True, it seems that the entropy loss is always the same. Maybe we should not compute bit-wise entropy?

lucidrains commented 10 months ago

@Lijun-Yu Hi Lijun! Thanks for opening this issue and for your new magvit2 paper

I will get this all resolved by sundown

lucidrains commented 10 months ago

@Lijun-Yu what surprises me the most is the lack of activation. what did you see when you tried it with a tanh or sigmoid?

lucidrains commented 10 months ago

@Lijun-Yu ok, all the points should be addressed, thank you for the code review

excited to see the magvit2 tokenizer replicated in the wild

lucidrains commented 10 months ago

not sure if this is relevant, but when I run LFQ with training=True, it seems that the entropy loss is always the same. Maybe we should not compute bit-wise entropy?

i've updated the entropy calculation, if you'd like to give it another shot

theAdamColton commented 10 months ago

@lucidrains Should aux loss be returned here, or just entropy aux loss?

https://github.com/lucidrains/vector-quantize-pytorch/blob/8b92bebb1961a27c1a4a6fde34e321e73729bbd4/vector_quantize_pytorch/lookup_free_quantization.py#L251C35-L251C51

lucidrains commented 10 months ago

@theAdamColton 🤦 should be fixed! thanks!

lucidrains commented 10 months ago

all points should be addressed

Mddct commented 1 month ago

A third point:

(3) We use global or sub-groups to estimate the codebook entropy, which is different from the bitwise entropy - https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/lookup_free_quantization.py#L183

How is the entropy loss calculated in LFQ? I don't know how to compute the 'distance' and the probs.
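For reference, a minimal PyTorch sketch of one way to get the 'distance' and probs for LFQ, in the spirit of the MAGVIT VectorQuantizer linked above; the function name, the temperature value, and the sign convention of the returned loss are assumptions for illustration:

```python
import torch

def lfq_entropy_loss(z: torch.Tensor, temperature: float = 0.01) -> torch.Tensor:
    """Sketch of an LFQ-style entropy loss over the full implicit codebook (no sub-groups).

    z: (batch, num_bits) pre-quantization values; the codebook is every {-1, +1}^num_bits
    combination, so no embedding table is needed.
    """
    num_bits = z.shape[-1]

    # implicit codebook: all 2^num_bits sign patterns, shape (2^num_bits, num_bits)
    codebook = torch.cartesian_prod(*([torch.tensor([-1.0, 1.0])] * num_bits))
    codebook = codebook.reshape(-1, num_bits).to(z)

    # the "distance" reduces to a dot product: since every code has the same norm,
    # -||z - c||^2 differs from 2 * (z . c) only by a per-sample constant,
    # which the softmax ignores
    logits = z @ codebook.t()                        # (batch, 2^num_bits)
    probs = (logits / temperature).softmax(dim=-1)   # (batch, 2^num_bits)

    # per-sample entropy: push each sample towards a confident, single code
    per_sample_entropy = -(probs * probs.clamp(min=1e-9).log()).sum(dim=-1).mean()

    # entropy of the batch-averaged distribution: push the batch to use many codes
    avg_probs = probs.mean(dim=0)
    codebook_entropy = -(avg_probs * avg_probs.clamp(min=1e-9).log()).sum()

    # minimize the first term, maximize the second
    return per_sample_entropy - codebook_entropy
```

For codebooks larger than 2^18, the same terms would be computed per sub-group of bits, as discussed earlier in the thread.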