lucidrains / vector-quantize-pytorch

Vector (and Scalar) Quantization, in Pytorch
MIT License

EMA update on CosineCodebook #26

Open roomo7time opened 1 year ago

roomo7time commented 1 year ago

The original ViT-VQGAN paper does not seem to use EMA updates for codebook learning, since its codebook consists of unit-normalized vectors.

In particular, as I understand it, EMA updates do not quite make sense when both the encoder outputs and the codebook vectors are unit-normalized.

What's your take on this? Should we NOT use EMA updates with CosineCodebook?
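To make the concern concrete, here is a minimal pure-Python sketch. It assumes a simplified EMA rule that moves a code vector directly toward the mean of its assigned unit-normalized inputs (the usual VQ-VAE EMA tracks cluster sizes and sums separately), and `ema_update`, `l2norm` and `norm` are hypothetical helpers, not this library's API. An average of unit vectors generally lies strictly inside the unit sphere, so without re-projection the EMA step pulls the code vector off the sphere that a cosine codebook assumes.

```python
import math


def norm(v):
    """Euclidean (L2) norm of a vector given as a list of floats."""
    return math.sqrt(sum(x * x for x in v))


def l2norm(v):
    """Project a vector back onto the unit sphere."""
    n = norm(v)
    return [x / n for x in v]


def ema_update(code, batch_mean, decay):
    """Simplified, hypothetical EMA step: code <- decay * code + (1 - decay) * batch_mean."""
    return [decay * c + (1.0 - decay) * m for c, m in zip(code, batch_mean)]


# A unit-norm code vector and a unit-norm input assigned to it, 90 degrees apart.
code = [1.0, 0.0]
assigned_mean = [0.0, 1.0]

updated = ema_update(code, assigned_mean, decay=0.5)
print(updated, norm(updated))  # [0.5, 0.5], norm ~0.7071: off the unit sphere

renormalized = l2norm(updated)
print(renormalized, norm(renormalized))  # unit norm again (up to float rounding)
```

One way to reconcile the two is to re-normalize after every EMA step, as the last lines do; whether that is a sound substitute for gradient-based codebook learning is the question this issue raises.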

pengzhangzhi commented 1 year ago

Could you explain why EMA does not work for a unit-normalized codebook?

Saltychtao commented 1 year ago

I found that when using EMA with the cosine codebook, the L2 norm of the input to the VQ module grows gradually (from 22 to 20,000), which makes the training loss grow as well. Has anyone else run into this problem?
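One plausible mechanism for this kind of norm growth (a hypothesis, not something established in this thread): when the VQ module normalizes its input before comparing it with the codebook, the loss the encoder sees is scale-invariant in the input x. The gradient of such a loss is orthogonal to x, so a plain gradient step can only increase the norm, because ||x - lr*g||^2 = ||x||^2 + lr^2 ||g||^2. The sketch below assumes a hypothetical loss L(x) = -cos(x, c) for a fixed unit code c and plain SGD; real training with Adam, momentum or weight decay will differ in detail.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def norm(v):
    return math.sqrt(dot(v, v))


def neg_cosine_grad(x, c):
    """Gradient of L(x) = -cos(x, c) w.r.t. x, assuming ||c|| = 1.

    dL/dx = -(c - (x_hat . c) * x_hat) / ||x||, which is orthogonal to x.
    """
    n = norm(x)
    x_hat = [xi / n for xi in x]
    proj = dot(x_hat, c)
    return [-(ci - proj * xh) / n for ci, xh in zip(c, x_hat)]


def sgd_input_norms(x, c, lr, steps):
    """Run plain SGD on the input x and record ||x|| before and after each step."""
    norms = [norm(x)]
    for _ in range(steps):
        g = neg_cosine_grad(x, c)
        x = [xi - lr * gi for xi, gi in zip(x, g)]
        norms.append(norm(x))
    return norms


# The input starts orthogonal to its (hypothetical) assigned code.
norms = sgd_input_norms([1.0, 0.0], [0.0, 1.0], lr=1.0, steps=5)
print(norms)  # non-decreasing: the first step goes from 1.0 to sqrt(2), the second to 1.5
```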

Saltychtao commented 1 year ago

In case anyone else has this problem: I added a LayerNorm layer after the vq_in projection, and that largely solved the growing-norm problem.
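A minimal pure-Python illustration of why a LayerNorm right after the input projection caps the scale the VQ module sees: with its affine parameters at their initial values (gain 1, bias 0), LayerNorm maps any non-constant d-dimensional vector to one with zero mean and L2 norm close to sqrt(d), however large the projection output has grown. `linear`, `layer_norm` and the 2 -> 3 weight matrix are hypothetical stand-ins; in PyTorch the analogous module would be `nn.Sequential(nn.Linear(dim_in, dim), nn.LayerNorm(dim))`, though exactly how it was wired into the vq_in projection is not shown here.

```python
import math


def linear(x, weight, bias):
    """y = W x + b, with weight given as a list of rows (out_dim x in_dim)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def layer_norm(x, eps=1e-5):
    """LayerNorm over one vector with gain 1 and bias 0 (the usual initial values)."""
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)  # biased variance, as LayerNorm uses
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]


def norm(v):
    return math.sqrt(sum(vi * vi for vi in v))


weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # hypothetical 2 -> 3 projection
bias = [0.0, 0.0, 0.0]

for scale in (1.0, 1000.0):
    x = [scale * 1.0, scale * 2.0]
    projected = linear(x, weight, bias)
    print(scale, norm(projected), norm(layer_norm(projected)))
# The projected norm grows with the input scale; the LayerNorm output stays near sqrt(3).
```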

jzhang38 commented 1 year ago

@Saltychtao I have run into a similar issue. Does vq_in refer to VectorQuantize.project_in?

Saltychtao commented 1 year ago

Yes.

santisy commented 3 months ago

@Saltychtao Hi, just to make sure I understand: the current version of the implementation already seems to apply a normalization (l2norm) after project_in. Even so, I still somehow run into the training-loss explosion with the current version.

lucidrains commented 3 months ago

@santisy want to try turning this on (following @Saltychtao's solution)?

Let me know if it helps.