kobiso opened this issue 3 years ago
This config is probably used for inference only. `DummyLoss` is simply not calculating any loss. Another thing is that `kl_weight` is `1e-8`, which makes the quantization loss extremely small. I don't think this is correct either.
Hope the author releases the training config soon. For now, I'll stick to vanilla VQ although it suffers from index collapse. The utilization of the 16384-code model is ~6% (~1000 valid codes), and the 1024-code model is around 50% (~500 valid codes). You can visualize the codes by decoding the VQ dictionary.
EDIT: The quantization loss for GumbelVQ is not big, around 0.005 when training from scratch with `kl_weight=1`.
I would also like some clarity on the best KL weight for training from scratch (and whether it should be warmed up over time).
@TomoshibiAkira Why do you expect better utilization of the codes with Gumbel quantization?
@borisdayma Because the codebook in the f=8 GumbelVQ model does not contain invalid codes, unlike the IN model.
By "invalid codes", I mean: In the IN model's codebook, there are several thousands of codes that have a very small L2-norm (around 5e-4) compared with other valid codes (around 15). These codes usually don't contain any interesting information as shown in the visualization (the first 1024 codes in the 16384 model).
Here's the visualization of the first 1024 codes in the f=8 GumbelVQ's codebook for comparison.
Thanks @TomoshibiAkira for this great explanation! Btw how did you create those visualizations?
You're welcome! @borisdayma I simply treat every code as a 1x1-size patch and forward it through the pretrained decoder.
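A rough sketch of that procedure for a taming-transformers VQModel (the attribute names `quantize.embedding` and `decode` are assumptions based on that codebase; the DALL-E variant appears further down in this thread):

```
import torch
from torchvision.utils import save_image

@torch.no_grad()
def visualize_codebook(model, batch_size=1024, nrow=32, prefix="codebook"):
    codes = model.quantize.embedding.weight            # (n_embed, embed_dim)
    for i in range(0, codes.shape[0], batch_size):
        # Treat each code as a 1x1 latent map: (B, embed_dim, 1, 1)
        z = codes[i:i + batch_size, :, None, None]
        x = model.decode(z)                            # (B, 3, H, W) decoded patches
        # Map the decoder output from roughly [-1, 1] to [0, 1] for saving
        save_image(x.clamp(-1, 1) * 0.5 + 0.5, f"{prefix}_{i // batch_size}.png", nrow=nrow)
```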
Thanks for your explanation, @TomoshibiAkira. How do you set the `temperature_scheduler_config`?
@TomoshibiAkira wondering if you looked at codebook utilization for other models (like OpenAI dVAE).
@borisdayma I haven't been playing around with VQ for a while, but hey, we're here. Why not :) Here's the visualization of DALL-E's discrete codes (the first 1024 codes of 8192 in total).
DALL-E suffers much harsher information loss than VQ. The reason might be that every code in DALL-E is an integer (or "a class") and thus carries much less information than VQ's codes (feature vectors). As a result, although both are f=8, the visualization of DALL-E's codes is much muddier and less detailed than GumbelVQ's codebook.
As for utilization, DALL-E's discretization method is different from VQ's. For VQ, I can compute the norm of a code and tell whether it's valid, since the invalid ones are always very different from the valid ones. For DALL-E, there's no way to explicitly determine whether a code is valid or not. Every code can be decoded into a patch, and from the look of it, every code seems to be occupied. There are some duplications, but that also happens in GumbelVQ's codebook, so one might say DALL-E's codes are 100% utilized.
EDIT: Here is the code for visualization; you can use it directly in the usage.ipynb provided by DALL-E:
```
import os
import torch
import torch.nn.functional as F
from torchvision.utils import save_image

# enc, dec and unmap_pixels come from DALL-E's usage.ipynb
num_classes = enc.vocab_size
batch_size = 1024
os.makedirs("viz", exist_ok=True)

for i in range(num_classes // batch_size):
    # Build a batch of one-hot "codes", one per codebook entry: (B, num_classes, 1, 1)
    ind = torch.arange(i * batch_size, (i + 1) * batch_size)
    z = F.one_hot(ind, num_classes=num_classes).unsqueeze(-1).unsqueeze(-1).float()
    # Decode each one-hot code into an image patch and save a grid per batch
    x_stats = dec(z.cuda()).float()
    x_recs = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
    save_image(x_recs, "viz/dic{}.png".format(i), nrow=32)
```
@sczhou The default parameters in GumbelVQ's model.yaml seem okay to me. That may be a good starting point.
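For reference, the `temperature_scheduler_config` in that model.yaml (quoted in full at the end of this thread) amounts to a cosine decay of the Gumbel-Softmax temperature from 0.9 down to ~1e-6 over about a million steps. A sketch of that schedule (this mirrors the config values, not the exact `LambdaWarmUpCosineScheduler` implementation):

```
import math

def gumbel_temperature(step, t_max=0.9, t_min=1.0e-06, max_decay_steps=1_000_001):
    # Cosine decay from t_max to t_min over max_decay_steps
    # (values copied from the pretrained f=8 GumbelVQ config)
    t = min(step, max_decay_steps) / max_decay_steps
    return t_min + 0.5 * (t_max - t_min) * (1.0 + math.cos(math.pi * t))

for step in (0, 250_000, 500_000, 1_000_000):
    print(step, round(gumbel_temperature(step), 6))
```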
Thanks, @TomoshibiAkira. Where could I find GumbelVQ's model.yaml? I didn't see this config file in this repo.
Many thanks.
It's in the pretrained model zoo. https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/?p=%2F&mode=list
@TomoshibiAkira Don't they both use a codebook where you can use either the codebook index or the corresponding feature vector? I thought that overall the one from OpenAI is blurrier mainly because it averages over patches (with MSE loss), vs the GAN and perceptual losses from VQGAN, which force it to sharpen the image.
@borisdayma I personally don't think so. In the image reconstruction example from usage.ipynb, the discretization method of DALL-E is the argmax function.
The output feature of DALL-E's encoder is directly argmaxed along the channel dimension and then transformed into a one-hot vector. The one-hot vector is then sent to the decoder and followed by normal Conv2D ops. One can say that DALL-E's decoder is actually decoding the INDEX of the encoder's feature.
To put it in VQ terms, you can say the 8192 distinct one-hot vectors are DALL-E's codebook. VQGAN maps the continuous feature into a codebook with 8192 codes, where every code is an $\mathbb{R}^{256}$ vector. DALL-E also maps the continuous feature into a codebook with 8192 codes, but every code is a one-hot vector. The amount of information one code can represent in these two methods is vastly different, IMO.
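To make the contrast concrete, here's a toy sketch of the two discretization steps being compared (random tensors only, not the actual model code):

```
import torch
import torch.nn.functional as F

B, H, W, n_codes = 1, 32, 32, 8192

# DALL-E style: the encoder emits one logit per codebook "class"; discretization is an argmax,
# and the decoder only ever sees the resulting one-hot vectors (i.e. the indices).
logits = torch.randn(B, n_codes, H, W)
dalle_idx = logits.argmax(dim=1)                                     # (B, H, W) integer tokens
dalle_codes = F.one_hot(dalle_idx, n_codes).permute(0, 3, 1, 2).float()

# VQGAN style: the encoder emits a continuous feature; discretization is a nearest-neighbour
# lookup into a learned codebook of feature vectors, and the decoder sees those vectors.
D = 256
codebook = torch.randn(n_codes, D)
feat = torch.randn(B, D, H, W)
flat = feat.permute(0, 2, 3, 1).reshape(-1, D)                       # (B*H*W, D)
vq_idx = torch.cdist(flat, codebook).argmin(dim=1)                   # nearest code per position
vq_codes = codebook[vq_idx].reshape(B, H, W, D).permute(0, 3, 1, 2)  # quantized features
```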
> I thought that overall the one from OpenAI is blurrier mainly because it averages over patches (with MSE loss), vs the GAN and perceptual losses from VQGAN, which force it to sharpen the image.
Well, the GAN and perceptual losses are definitely helping, but I do think even without them VQGAN (or a plain VQ-VAE) could achieve better reconstruction results. Here's one thought: if we keep everything else in VQGAN intact and change the codebook into DALL-E style, will it have the same performance? If it does, it means that the actual feature DOES NOT matter at all and indices are good enough for feature discretization, which is pretty counterintuitive. But neural networks are pretty counterintuitive as a whole, so yeah :D
I think the new Gumbel VQGAN type already has a DALL-E style codebook. It does indeed seem better, but I think this comes down to the quantization method preventing codebook collapse. The DALL-E decoder just uses a simple 1x1 conv2d layer to transform the one-hots into feature vectors (it's a one-to-one mapping); I have opened the decoder up and used the features directly instead.
> indices are good enough for feature discretization
They have to be because the second stage transformer models only produce indices, not features.
> The DALL-E decoder just uses a simple 1x1 conv2d layer to transform the one-hots into feature vectors (it's a one-to-one mapping); I have opened the decoder up and used the features directly instead.
Ah, now I see. If we combine the Conv2D layer with the one-hots (only considering the output of the 1x1 Conv2D layer), it's actually the same as VQ's codebook (with or without Gumbel). The codebook here is actually the weight of the Conv2D layer. Both @borisdayma and you are right. My bad!
Since they're the same,
> If it does, it means that the actual feature DOES NOT matter at all and indices are good enough for feature discretization.
This hypothesis is invalid from the start.
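A quick toy check of that equivalence (a 1x1 Conv2d applied to one-hot vectors is just a lookup into the conv weight, i.e. an implicit codebook):

```
import torch
import torch.nn.functional as F

n_codes, dim = 8192, 128
conv = torch.nn.Conv2d(n_codes, dim, kernel_size=1, bias=False)   # DALL-E-style decoder input layer

idx = torch.randint(0, n_codes, (4,))
one_hot = F.one_hot(idx, n_codes).float()[:, :, None, None]       # (4, n_codes, 1, 1)

via_conv = conv(one_hot)[:, :, 0, 0]                              # what the decoder sees
via_lookup = conv.weight[:, idx, 0, 0].t()                        # the same entries read directly from the weight
print(torch.allclose(via_conv, via_lookup, atol=1e-6))            # -> True
```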
> They have to be because the second stage transformer models only produce indices, not features.
Sorry if my statement wasn't clear before. What I really wanted to say is that "if we throw away the contents of the codebook, the indices of the codes alone are good enough to reconstruct the original image". But it doesn't matter now, since the hypothesis isn't valid at all 🤣
> It does indeed seem better, but I think this comes down to the quantization method preventing codebook collapse.
I'm not sure at this point. From my personal experience with a VQ autoencoder, the network's behavior at f=8 and f=16 is vastly different on the same dataset. If someone would train an f=8 model without Gumbel to check the codebook utilization, that would be very helpful.
Hi @TomoshibiAkira, this is really a valuable discussion! May I know if you have validated the performance of f=8 without Gumbel? Actually, I just want to see the effect of Gumbel, i.e., whether adding Gumbel to vanilla VQ always improves reconstruction and codebook utilization (e.g., at f=8, f=16), or whether there is some trade-off such as high codebook utilization but relatively low accuracy of code matching. Do you have any idea on that?
@fnzhan I didn't conduct the experiment, so I can't give any concrete answer. Personally, I'd like to believe that Gumbel can improve performance without any trade-off, since it's basically a better method for sampling discrete data. But I didn't dive deep into the theory of Gumbel-Softmax, so please take this with a grain of salt.
I think the tradeoff is during training: you have to train longer because you have to slowly decrease the Gumbel-Softmax temperature to 0 or very near 0. But I think it is straightforwardly better during inference.
@TomoshibiAkira @crowsonkb Thanks for sharing your insights. I have been working on this recently and will update if a concrete conclusion is reached.
@fnzhan any updates? Really interested to see if there are any key improvements.
Hi @EmaadKhwaja, I am preparing a paper on this. Here is a brief observation: comparing the original VQ and GumbelVQ (both f=16), the improvement with Gumbel tends to be marginal, although its codebook utilization is nearly 100%.
@fnzhan Hmm, that's interesting! This might mean that the actual usage of the codes is very unbalanced regardless of codebook utilization (e.g., the network tends to use a few "special" codes much more than others), which unfortunately would mean that the codebook collapse issue is still very much present. Statistics on the indices of the used codes would be helpful to verify this. Anyway, good luck with the paper!
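A sketch of those usage statistics, assuming a taming-transformers-style model whose `encode()` returns the code indices as the last element of its info tuple and a dataloader that yields image tensors (adjust the unpacking for your version):

```
import torch

@torch.no_grad()
def code_usage(model, dataloader, n_embed, device="cuda"):
    counts = torch.zeros(n_embed, device=device)
    for x in dataloader:
        # VQModel.encode returns (quant, emb_loss, info); the last entry of info holds the indices
        _, _, info = model.encode(x.to(device))
        counts += torch.bincount(info[2].flatten(), minlength=n_embed)
    probs = counts / counts.sum()
    entropy = -(probs[probs > 0] * probs[probs > 0].log()).sum()
    return {
        "codes_used": int((counts > 0).sum()),
        "utilization": float((counts > 0).float().mean()),
        "usage_perplexity": float(entropy.exp()),  # close to n_embed only if usage is balanced
    }
```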
Thanks for sharing!
Hey guys, if you are still interested in optimizing codebooks: I tried using the codebook with a projection and L2 norm from https://arxiv.org/pdf/2110.04627v3.pdf, and it works well. Here is a codebook; it has great color and various shapes.
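For anyone curious, here's a rough sketch of those two tricks (a learned projection into a low-dimensional code space plus an L2-normalized lookup); the dimensions and the straight-through estimator are illustrative choices, not the paper's exact setup:

```
import torch
import torch.nn.functional as F
from torch import nn

class L2ProjectedCodebook(nn.Module):
    def __init__(self, n_codes=8192, feat_dim=256, code_dim=32):
        super().__init__()
        self.proj = nn.Linear(feat_dim, code_dim)        # factorized, low-dimensional code space
        self.codebook = nn.Embedding(n_codes, code_dim)

    def forward(self, feat):                             # feat: (B, N, feat_dim)
        z = F.normalize(self.proj(feat), dim=-1)         # unit-norm features
        emb = F.normalize(self.codebook.weight, dim=-1)  # unit-norm codes
        # On the unit sphere, nearest neighbour in L2 == highest cosine similarity
        indices = (z @ emb.t()).argmax(dim=-1)
        quant = emb[indices]
        # Straight-through estimator so gradients still reach the encoder/projection
        quant = z + (quant - z).detach()
        return quant, indices
```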
@fnzhan Congratulations on your article being accepted to CVPR 2023! Would you kindly share your code and pre-trained weights? It would help us better understand and follow up on your work.
Hi guys, I have tried training another VQModel (first stage) on my own dataset (I modified the encoder and decoder a little). However, during training the vector quantization loss rises and the KL loss also rises. Any suggestions?
Thank you for the great work!
I tried to reproduce the VQGAN OpenImages (f=8), 8192, GumbelQuantization model based on the config file from the cloud (the detailed config file is below):

```
model:
  base_learning_rate: 4.5e-06
  target: taming.models.vqgan.GumbelVQ
  params:
    kl_weight: 1.0e-08
    embed_dim: 256
    n_embed: 8192
    monitor: val/rec_loss
    temperature_scheduler_config:
      target: taming.lr_scheduler.LambdaWarmUpCosineScheduler
      params:
        warm_up_steps: 0
        max_decay_steps: 1000001
        lr_start: 0.9
        lr_max: 0.9
        lr_min: 1.0e-06
    ddconfig:
      double_z: false
      z_channels: 256
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 1
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions:
      - 32
      dropout: 0.0
    lossconfig:
      target: taming.modules.losses.vqperceptual.DummyLoss
```

However, I encountered some errors when training with GumbelQuantization. The first error was an unexpected keyword argument error, as below.
I could fix this error by removing `return_pred_indices=True` from the line below:
https://github.com/CompVis/taming-transformers/blob/9d17ea64b820f7633ea6b8823e1f78729447cb57/taming/models/vqgan.py#L336

The second error occurs because of `DummyLoss`, as below. This can be fixed by changing `target: taming.modules.losses.vqperceptual.DummyLoss` to `target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator`.

But the thing is, I am not sure whether the VQGAN OpenImages (f=8), 8192, GumbelQuantization model was trained with the discriminator loss, when it was turned on, and with what parameters. Can you share the detailed config file of the VQGAN OpenImages (f=8), 8192, GumbelQuantization model and fix the above issues so that the model can be reproduced?

Thank you in advance!