lucidrains / vector-quantize-pytorch

Vector (and Scalar) Quantization, in Pytorch
MIT License

LatentQuantize exploding loss #151

Open jbmaxwell opened 3 months ago

jbmaxwell commented 3 months ago

I'm trying to use the LatentQuantize model in an autoencoder context. My inputs are flat 1-D tensors of size 32, and my encoder passes a tensor of shape (batch_size, 64) to the quantizer. For now, "levels" is [8, 6, 4] and latent_dim is 64:

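```python
from vector_quantize_pytorch import LatentQuantize

# inside the autoencoder's __init__
self.lq = LatentQuantize(
    levels=levels,       # [8, 6, 4]
    dim=latent_dim,      # 64, the encoder output size
    commitment_loss_weight=0.1,
    quantization_loss_weight=0.1,
)
```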
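For intuition about what a quantizer configured with levels [8, 6, 4] does, here is a minimal pure-Python sketch of per-dimension latent quantization. It assumes (this is not taken from the library source) that each latent dimension i gets levels[i] evenly spaced values in [-1, 1], that each coordinate is snapped to its nearest value, and that the 64-dim encoder output has already been projected down to 3 coordinates, one per entry in levels. The helpers make_value_grids, quantize and flat_index are hypothetical, and the mixed-radix ordering used to flatten codes is an assumption.

```python
import math
from typing import List, Sequence, Tuple


def make_value_grids(levels: Sequence[int]) -> List[List[float]]:
    """Hypothetical helper: levels[i] evenly spaced values in [-1, 1] for latent dim i."""
    grids = []
    for n in levels:
        assert n >= 2, "each latent dimension needs at least two values"
        step = 2.0 / (n - 1)
        grids.append([-1.0 + k * step for k in range(n)])
    return grids


def quantize(z: Sequence[float], grids: List[List[float]]) -> Tuple[List[float], List[int]]:
    """Snap each coordinate of z to the nearest allowed value of its own dimension.

    Returns the quantized vector and the chosen level index per dimension.
    """
    assert len(z) == len(grids), "expects one coordinate per entry in levels"
    zq, idx = [], []
    for value, grid in zip(z, grids):
        k = min(range(len(grid)), key=lambda j: abs(grid[j] - value))
        idx.append(k)
        zq.append(grid[k])
    return zq, idx


def flat_index(idx: Sequence[int], levels: Sequence[int]) -> int:
    """Mixed-radix flattening of per-dimension indices (ordering is an assumption)."""
    index, radix = 0, 1
    for k, n in zip(idx, levels):
        index += k * radix
        radix *= n
    return index


levels = [8, 6, 4]
grids = make_value_grids(levels)
codebook_size = math.prod(levels)  # 8 * 6 * 4 = 192 implicit codes
zq, idx = quantize([0.1, -0.95, 2.0], grids)
code = flat_index(idx, levels)
```

With these inputs, idx is [4, 0, 3] and the flattened code is 148 out of 192. The third coordinate, 2.0, lies outside [-1, 1], so it is clamped to the top value and carries a large quantization error; nothing in this sketch bounds the encoder output.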

The loss starts at zero, then exponentially increases:

[Screenshot, 2024-08-01, 1:53:43 PM: loss curve]

Any thoughts as to why this might happen?
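To reason about what such a loss contains, here is a forward-value sketch of a weighted quantization-plus-commitment loss in plain Python. It assumes both terms are mean squared errors between the encoder output z and its quantized version zq, in the style of VQ-VAE commitment losses; whether LatentQuantize reduces them exactly this way is not confirmed in this thread. The helpers mse and aux_loss are hypothetical, and the default weights mirror the 0.1 values in the configuration above.

```python
from typing import Sequence


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error between two equal-length vectors."""
    assert len(a) == len(b)
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def aux_loss(z: Sequence[float], zq: Sequence[float],
             commitment_loss_weight: float = 0.1,
             quantization_loss_weight: float = 0.1) -> float:
    """Weighted quantization + commitment loss, forward values only.

    With autograd, the quantization term would stop gradients through z and the
    commitment term would stop gradients through zq. Stop-gradient does not change
    the forward value, so here both terms are the same mean squared distance.
    """
    quantization_loss = mse(zq, z)  # would update the quantized values
    commitment_loss = mse(z, zq)    # would update the encoder output
    return (quantization_loss_weight * quantization_loss
            + commitment_loss_weight * commitment_loss)


zq = [1.0, -1.0, 1.0]
# Push one coordinate of z further from its quantized value and watch the loss grow.
losses = [aux_loss([s * 1.5, -1.0, 1.0], zq) for s in (1.0, 2.0, 4.0)]
```

Because the loss is quadratic in the distance between z and zq, it rises quickly whenever latents move away from the values they are snapped to. This is one way such a curve can climb, not a diagnosis of the screenshot.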

lucidrains commented 3 months ago

someone else actually contributed this in a pull request, and I'm unfamiliar with it

does the reconstruction loss look good?
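That question is easier to answer when each loss term is logged on its own rather than only their sum. A minimal sketch with a hypothetical LossLog helper and made-up numbers; in a real training loop, recon would come from your reconstruction criterion and aux from the quantizer's auxiliary loss.

```python
from collections import defaultdict
from typing import Dict


class LossLog:
    """Hypothetical helper: running mean of each loss term over a logging window."""

    def __init__(self) -> None:
        self.sums: Dict[str, float] = defaultdict(float)
        self.counts: Dict[str, int] = defaultdict(int)

    def add(self, **terms: float) -> None:
        for name, value in terms.items():
            self.sums[name] += value
            self.counts[name] += 1

    def flush(self) -> Dict[str, float]:
        """Return the per-term means and reset the window."""
        means = {name: self.sums[name] / self.counts[name] for name in self.sums}
        self.sums.clear()
        self.counts.clear()
        return means


log = LossLog()
for recon, aux in [(0.9, 0.01), (0.7, 0.02), (0.5, 0.03)]:  # illustrative values
    log.add(recon=recon, aux=aux, total=recon + aux)
means = log.flush()
```

Plotting recon and aux as separate curves shows directly whether a rising total comes from the reconstruction or from the quantizer.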

jbmaxwell commented 3 months ago

Well, actually, I've just discovered that it seems to be a learning-rate (LR) issue...

[Screenshot, 2024-08-01, 2:29:56 PM: loss curve]

Zooming in on the first 10k steps:

[Screenshot, 2024-08-01, 2:29:26 PM: first 10k steps]

But the reconstruction loss seems to converge pretty steadily, so maybe it's just a false alarm. I still have to wrap my head around how LatentQuantize works (and how to get what I want from it), mind you!
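The thread attributes the rising loss to the learning rate but does not say what was changed. One common way to tame early instability is a linear warmup followed by a cosine decay; below is a hedged sketch of such a schedule. The function lr_at and every number in it (base_lr, warmup_steps, total_steps, min_lr) are placeholders, not values from this thread.

```python
import math


def lr_at(step: int, base_lr: float = 3e-4, warmup_steps: int = 1_000,
          total_steps: int = 100_000, min_lr: float = 1e-5) -> float:
    """Hypothetical schedule: linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        # ramp linearly from base_lr / warmup_steps up to base_lr
        return base_lr * (step + 1) / warmup_steps
    # fraction of the decay phase completed, clamped so the LR stays at min_lr afterwards
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

In PyTorch, torch.optim.lr_scheduler.LambdaLR multiplies the optimizer's initial learning rate by the factor its function returns, so passing lambda step: lr_at(step) / base_lr and calling scheduler.step() once per optimizer step would apply this schedule.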