CompVis / taming-transformers

Taming Transformers for High-Resolution Image Synthesis
https://arxiv.org/abs/2012.09841
MIT License

About GumbelQuantization training #67

Open kobiso opened 3 years ago

kobiso commented 3 years ago

Thank you for the great work!

I tried to reproduce the VQGAN OpenImages (f=8), 8192, GumbelQuantization model based on the config file from the cloud storage (the detailed config file is below).

VQGAN OpenImages (f=8), 8192, GumbelQuantization

```
model:
  base_learning_rate: 4.5e-06
  target: taming.models.vqgan.GumbelVQ
  params:
    kl_weight: 1.0e-08
    embed_dim: 256
    n_embed: 8192
    monitor: val/rec_loss
    temperature_scheduler_config:
      target: taming.lr_scheduler.LambdaWarmUpCosineScheduler
      params:
        warm_up_steps: 0
        max_decay_steps: 1000001
        lr_start: 0.9
        lr_max: 0.9
        lr_min: 1.0e-06
    ddconfig:
      double_z: false
      z_channels: 256
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
        - 1
        - 1
        - 2
        - 4
      num_res_blocks: 2
      attn_resolutions:
        - 32
      dropout: 0.0
    lossconfig:
      target: taming.modules.losses.vqperceptual.DummyLoss
```

However, I encountered some errors when training with GumbelQuantization. The first error was an unexpected keyword argument error, shown below.

  File "/opt/conda/envs/taming/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 57, in forward
    output = self.module.validation_step(*inputs, **kwargs)
  File "/home/shared/workspace/dalle/taming-transformers/taming/models/vqgan.py", line 336, in validation_step
    xrec, qloss = self(x, return_pred_indices=True)
  File "/opt/conda/envs/taming/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() got an unexpected keyword argument 'return_pred_indices'

I could fix this error by removing return_pred_indices=True from the line below. https://github.com/CompVis/taming-transformers/blob/9d17ea64b820f7633ea6b8823e1f78729447cb57/taming/models/vqgan.py#L336

The second error occurs because of DummyLoss, as shown below.

  File "/opt/conda/envs/taming/lib/python3.8/site-packages/pytorch_lightning/trainer/optimizers.py", line 34, in init_optimizers
    optim_conf = model.configure_optimizers()
  File "/home/shared/workspace/dalle/taming-transformers/taming/models/vqgan.py", line 129, in configure_optimizers
    opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
  File "/opt/conda/envs/taming/lib/python3.8/site-packages/torch/nn/modules/module.py", line 947, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'DummyLoss' object has no attribute 'discriminator'

This can be fixed by changing target: taming.modules.losses.vqperceptual.DummyLoss to target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator.

But the thing is, I am not sure whether the VQGAN OpenImages (f=8), 8192, GumbelQuantization model was trained with a discriminator loss, and if so, when the discriminator was enabled and with what parameters. Could you share the detailed config file of the VQGAN OpenImages (f=8), 8192, GumbelQuantization model and fix the above issues so that the model can be reproduced?

Thank you in advance!

TomoshibiAkira commented 3 years ago

This config is probably meant for inference only: DummyLoss simply doesn't calculate any loss. Another thing is that kl_weight is 1e-8, which makes the quantization loss extremely small; I don't think this is correct for training either.

Hope the author will release the training config soon. For now, I'll stick to vanilla VQ, although it suffers from index collapse. The codebook utilization of the 16384 model is ~6% (~1000 valid codes), and the 1024 model's is around 50% (~500 valid codes). You can visualize the codes by decoding the VQ dictionary.

EDIT: The quantization loss for GumbelVQ is not big, around 0.005 when training from scratch with kl_weight=1.
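For intuition, here's a hedged sketch of the KL-to-uniform regularizer commonly used with Gumbel quantization (the repo's GumbelQuantize implementation may differ in details); it also shows why kl_weight=1e-8 makes the term negligible:

```
import torch
import torch.nn.functional as F

# Hedged sketch of the KL-to-uniform regularizer commonly used with Gumbel quantization
# (the repo's GumbelQuantize implementation may differ in details).
def kl_to_uniform(logits, kl_weight=1.0):
    # logits: [B, n_embed, H, W] -- the encoder's scores over the codebook
    qy = F.softmax(logits, dim=1)
    n_embed = logits.shape[1]
    # KL(q || uniform) = sum_i q_i * log(q_i * n_embed); zero when q is exactly uniform
    kl = torch.sum(qy * torch.log(qy * n_embed + 1e-10), dim=1).mean()
    return kl_weight * kl

# With kl_weight=1e-8 this term is effectively zero, which supports the guess that the
# posted config is for inference rather than training.
```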

crowsonkb commented 3 years ago

I would also like some clarity on the best KL weight for training from scratch (and whether it should be warmed up over time).

borisdayma commented 3 years ago

@TomoshibiAkira Why do you expect to have a better utilization of the codes with Gumbel Quantization?

TomoshibiAkira commented 3 years ago

@borisdayma Because the codebook of the f=8 GumbelVQ model does not contain invalid codes, unlike the IN (ImageNet) model's.

By "invalid codes", I mean: In the IN model's codebook, there are several thousands of codes that have a very small L2-norm (around 5e-4) compared with other valid codes (around 15). These codes usually don't contain any interesting information as shown in the visualization (the first 1024 codes in the 16384 model). image

Here's the visualization of the first 1024 codes in the f=8 GumbelVQ's codebook for comparison.

borisdayma commented 3 years ago

Thanks @TomoshibiAkira for this great explanation! Btw how did you create those visualizations?

TomoshibiAkira commented 3 years ago

You're welcome! @borisdayma I simply treat every code as a 1x1-size patch and forward it through the pretrained decoder.
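In case it helps others, a rough sketch of that procedure, assuming a pretrained taming VQModel loaded as model (the codebook attribute path is an assumption and differs for GumbelVQ):

```
import torch
from torchvision.utils import save_image

# Rough sketch: decode each codebook entry as a 1x1 latent patch. Assumes a pretrained
# taming VQModel loaded as `model`; for GumbelVQ the codebook may sit under a different
# attribute (e.g. model.quantize.embed instead of model.quantize.embedding).
with torch.no_grad():
    codes = model.quantize.embedding.weight          # [n_embed, embed_dim]
    z = codes[:1024].unsqueeze(-1).unsqueeze(-1)     # 1024 codes as 1x1 latent "images"
    patches = model.decode(z)                        # roughly 8x8-pixel patches for f=8
    save_image((patches.clamp(-1, 1) + 1) / 2, "codebook_viz.png", nrow=32)
```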

sczhou commented 3 years ago

Thanks for your explanation, @TomoshibiAkira. How do you set the temperature_scheduler_config?

borisdayma commented 3 years ago

@TomoshibiAkira wondering if you looked at codebook utilization for other models (like OpenAI dVAE).

TomoshibiAkira commented 3 years ago

@borisdayma I haven't been playing around with VQ for a while, but hey, we're here. Why not :) Here's the visualization of DALL-E's discrete codes (the first 1024 codes of 8192 in total).


DALL-E suffers much harsher information loss than VQ. The reason might be that every code in DALL-E is an integer (or "a class") and thus carries much less information than VQ's codes (feature vectors). As a result, although both are f=8, the visualization of DALL-E's codes is much muddier and less detailed than GumbelVQ's codebook.

As for utilization, DALL-E's discretization method is different from VQ's. For VQ, I can compute the norm of a code and, from that, tell whether it's valid, since the invalid ones are always very different from the valid ones. For DALL-E, there's no such way to explicitly determine whether a code is valid or not. Every code can be decoded into a patch, and from the look of it, every code seems occupied. Although there are some duplications, that also happens in GumbelVQ's codebook, so one might say DALL-E's codes are 100% utilized.

EDIT: Here's the code for the visualization; you can use it directly in the usage.ipynb provided by DALL-E.

import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from dall_e import unmap_pixels  # enc/dec are the pretrained DALL-E encoder/decoder from usage.ipynb

num_classes = enc.vocab_size
batch_size = 1024
for i in range(num_classes // batch_size):
    # decode each one-hot code (shaped [B, vocab, 1, 1]) into an image patch
    ind = torch.arange(i * batch_size, (i + 1) * batch_size)
    z = F.one_hot(ind, num_classes=num_classes).unsqueeze(-1).unsqueeze(-1).float()
    x_stats = dec(z.cuda()).float()
    x_recs = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
    save_image(x_recs, "viz/dic{}.png".format(i), nrow=32)

@sczhou The default parameters in GumbelVQ's model.yaml seem okay to me. They may be a good starting point.

sczhou commented 3 years ago

Thanks, @TomoshibiAkira. Where could I find GumbelVQ's model.yaml? I didn't see this config file in this repo.

Many thanks.

TomoshibiAkira commented 3 years ago

Thanks, @TomoshibiAkira. Where could I find GumbelVQ's model.yaml? I didn't see this config file in this repo.

Many thanks.

It's in the pretrained model zoo. https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/?p=%2F&mode=list

borisdayma commented 3 years ago

DALL-E suffers much harsher information loss than VQ. The reason might be that every code in DALL-E is an integer (or "a class") and thus carries much less information than VQ's codes (feature vectors). As a result, although both are f=8, the visualization of DALL-E's codes is much muddier and less detailed than GumbelVQ's codebook.

@TomoshibiAkira Don't they both use a codebook where you can use either the codebook index or the corresponding feature vector? I thought the one from OpenAI is blurrier mainly because it averages over patches (with an MSE loss), whereas the GAN loss and perceptual loss in VQGAN force it to sharpen the image.

TomoshibiAkira commented 3 years ago

@TomoshibiAkira Don't they both use a codebook where you can use either the codebook index or the corresponding feature vector?

@borisdayma I personally don't think so. In the image reconstruction example from usage.ipynb, DALL-E's discretization method is the argmax function. The output feature of DALL-E's encoder is argmaxed directly along the channel dimension and then transformed into a one-hot vector. The one-hot vector is then sent to the decoder, followed by normal Conv2D ops. One could say that DALL-E's decoder is actually decoding the INDEX of the encoder's feature.

To put it in VQ terms, you could say that the 8192 distinct one-hot vectors form DALL-E's codebook. VQGAN maps the continuous feature into a codebook of 8192 codes, where every code is an $\mathbb{R}^{256}$ vector. DALL-E also maps the continuous feature into a codebook of 8192 codes, but every code is a one-hot vector. The amount of information one code can represent in these two methods is vastly different, IMO.

I thought the one from OpenAI is blurrier mainly because it averages over patches (with an MSE loss), whereas the GAN loss and perceptual loss in VQGAN force it to sharpen the image.

Well, the GAN and perceptual losses are definitely helping, but I do think that even without them VQGAN (or a plain VQ-VAE) could achieve better reconstruction results than DALL-E. Here's one thought: if we keep everything else in VQGAN intact and change the codebook to DALL-E style, will it have the same performance? If it does, it means the actual feature DOES NOT matter at all and indices are good enough for feature discretization, which is pretty counterintuitive. But neural networks are pretty counterintuitive as a whole, so yeah :D

crowsonkb commented 3 years ago

@borisdayma I personally don't think so. In the image reconstruction example from usage.ipynb, DALL-E's discretization method is the argmax function... Here's one thought: if we keep everything else in VQGAN intact and change the codebook to DALL-E style, will it have the same performance? If it does, it means the actual feature DOES NOT matter at all and indices are good enough for feature discretization, which is pretty counterintuitive. But neural networks are pretty counterintuitive as a whole, so yeah :D

I think the new Gumbel VQGAN type already has a DALL-E-style codebook. It does indeed seem better but I think this comes down to the quantization method preventing codebook collapse. The DALL-E decoder just uses a simple 1x1 conv2d layer to transform the one-hots into feature vectors (it's a one-to-one mapping); I have opened the decoder up and used the features directly instead.

indices are good enough for feature discretization

They have to be because the second stage transformer models only produce indices, not features.

TomoshibiAkira commented 3 years ago

The DALL-E decoder just uses a simple 1x1 conv2d layer to transform the one-hots into feature vectors (it's a one-to-one mapping); I have opened the decoder up and used the features directly instead.

Ah, now I see. If we combine the Conv2D layer with the one-hots (i.e., only consider the output of the 1x1 Conv2D layer), it's actually the same as VQ's codebook (with or without Gumbel). The codebook here is just the weight of the Conv2D layer. Both @borisdayma and you are right. My bad!
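A tiny illustrative sketch of that equivalence (random weights and arbitrary sizes, not DALL-E's actual decoder):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sketch: a 1x1 Conv2d over one-hot vectors is exactly a row lookup in the
# conv weight, so the conv weight plays the role of VQ's codebook.
vocab, dim = 8192, 256
conv = nn.Conv2d(vocab, dim, kernel_size=1, bias=False)

idx = torch.randint(0, vocab, (4,))                     # four code indices
onehot = F.one_hot(idx, vocab).float().view(4, vocab, 1, 1)
via_conv = conv(onehot).view(4, dim)                    # "DALL-E style": conv over one-hots
via_lookup = conv.weight.view(dim, vocab).t()[idx]      # "VQ style": lookup in the weight

print(torch.allclose(via_conv, via_lookup, atol=1e-6))  # True: same feature vectors
```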

Since they're the same,

If it does, it means the actual feature DOES NOT matter at all and indices are good enough for feature discretization.

This hypothesis is invalid from the start.

They have to be because the second stage transformer models only produce indices, not features.

Sorry if my earlier statement wasn't clear. What I really wanted to say is that "if we throw away the contents of the codebook, the indices of the codes alone are good enough to reconstruct the original image". But it doesn't matter now, since the hypothesis isn't valid at all 🤣

It does indeed seem better but I think this comes down to the quantization method preventing codebook collapse.

I'm not sure at this point. From my personal experience with a VQ autoencoder, the network's behavior at f=8 vs. f=16 is vastly different on the same dataset. If someone could train an f=8 model without Gumbel to check the codebook utilization, that would be very helpful.

fnzhan commented 2 years ago

Hi @TomoshibiAkira, this is really a valuable discussion! May I ask whether you validated the performance of f=8 without Gumbel? Actually, I just want to see the effect of Gumbel, i.e., whether adding Gumbel to vanilla VQ always improves reconstruction and codebook utilization (e.g., at f=8 and f=16), or whether there is some trade-off, such as high codebook utilization but relatively low accuracy of code matching. Do you have any ideas on that?

TomoshibiAkira commented 2 years ago

@fnzhan I didn't conduct the experiment, so I can't give a concrete answer. Personally, I'd like to believe that Gumbel can improve performance without any trade-off, since it's basically a better method for sampling discrete data. But I didn't dive deep into the theory of Gumbel-Softmax, so please take this with a grain of salt.

crowsonkb commented 2 years ago

Hi @TomoshibiAkira, this is really a valuable discussion! May I ask whether you validated the performance of f=8 without Gumbel? Actually, I just want to see the effect of Gumbel, i.e., whether adding Gumbel to vanilla VQ always improves reconstruction and codebook utilization (e.g., at f=8 and f=16), or whether there is some trade-off, such as high codebook utilization but relatively low accuracy of code matching. Do you have any ideas on that?

I think the trade-off is during training: you have to train longer because you have to slowly decrease the Gumbel-Softmax temperature to 0 or very near 0. But I think it is straightforwardly better at inference time.
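For concreteness, a hypothetical sketch of such an annealing schedule, using a plain cosine decay with the numbers from the config posted above (not the repo's exact LambdaWarmUpCosineScheduler):

```
import math
import torch
import torch.nn.functional as F

# Hypothetical sketch: cosine-decay the Gumbel-Softmax temperature from 0.9 to ~1e-6
# over ~1e6 steps, mirroring the posted config (not the repo's exact scheduler code).
def cosine_temperature(step, tau_max=0.9, tau_min=1e-6, max_steps=1_000_000):
    t = min(step / max_steps, 1.0)
    return tau_min + 0.5 * (tau_max - tau_min) * (1 + math.cos(math.pi * t))

logits = torch.randn(2, 8192, 32, 32)      # dummy encoder logits over an 8192-way codebook
for step in (0, 500_000, 1_000_000):
    tau = cosine_temperature(step)
    soft_code = F.gumbel_softmax(logits, tau=tau, hard=False, dim=1)
    peak = soft_code.max(dim=1).values.mean().item()   # approaches 1.0 as tau approaches 0
    print(f"step={step}: tau={tau:.6f}, mean peak prob={peak:.3f}")
```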

fnzhan commented 2 years ago

@TomoshibiAkira @crowsonkb Thanks for sharing your insights. I am working on this at the moment and will post an update once a concrete conclusion is reached.

EmaadKhwaja commented 2 years ago

@fnzhan any updates? Really interested to see if there are any key improvements.

fnzhan commented 2 years ago

Hi @EmaadKhwaja, I am preparing a paper regarding this. Here is a brief observation: comparing the original VQ and GumbelVQ (both f=16), the improvement with Gumbel tends to be marginal, although its codebook utilization is nearly 100%.

TomoshibiAkira commented 2 years ago

@fnzhan Hmm, that's interesting! This might mean that the actual usage of the codes is very unbalanced regardless of codebook utilization (e.g., the network tends to use a few "special" codes rather than the others), which unfortunately would mean the codebook collapse issue is still very much present. Statistics on the indices of the used codes would help verify this. Anyway, good luck with the paper!
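For example, a hedged sketch of such statistics, assuming you've collected the quantizer's code indices over a validation set:

```
import torch

# Hedged sketch: given the code indices produced while encoding a dataset (e.g. collected
# from the quantizer during validation as a LongTensor), measure how balanced usage is.
def codebook_usage_stats(indices: torch.Tensor, n_embed: int = 8192):
    counts = torch.bincount(indices.flatten(), minlength=n_embed).float()
    probs = counts / counts.sum()
    used = int((counts > 0).sum())                                    # raw utilization
    perplexity = torch.exp(-(probs * (probs + 1e-10).log()).sum()).item()
    return used, perplexity

# If `used` is close to n_embed but perplexity is far below it, a few "special" codes dominate.
```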

Zyriix commented 1 year ago

@borisdayma I haven't been playing around with VQ for a while, but hey, we're here. Why not :) Here's the visualization of DALL-E's discrete codes (the first 1024 codes of 8192 in total). [...]

Thanks for sharing!

Zyriix commented 1 year ago

Hey guys, if you are still interested in optimizing codebooks: I tried the codebook with projection and L2 norm from https://arxiv.org/pdf/2110.04627v3.pdf, and it works well. Here's a visualization of the resulting codebook; it has great color and various shapes.
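For anyone curious, a rough sketch of that idea as I understand it from the paper (factorized, L2-normalized codes; the class name and dimensions are illustrative, and the straight-through gradient is omitted):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

# Rough sketch of the ViT-VQGAN-style codebook (https://arxiv.org/abs/2110.04627):
# project encoder features to a low-dimensional space and L2-normalize both features and
# codes, so the nearest-neighbour match is by cosine similarity. Names and dimensions are
# illustrative; the straight-through gradient trick is omitted for brevity.
class L2ProjectedCodebook(nn.Module):
    def __init__(self, n_embed=8192, enc_dim=256, code_dim=32):
        super().__init__()
        self.proj = nn.Linear(enc_dim, code_dim)          # factorized lookup space
        self.embedding = nn.Embedding(n_embed, code_dim)

    def forward(self, h):                                 # h: [B, enc_dim, H, W]
        z = F.normalize(self.proj(h.permute(0, 2, 3, 1)), dim=-1)    # [B, H, W, code_dim]
        codes = F.normalize(self.embedding.weight, dim=-1)            # [n_embed, code_dim]
        idx = (z @ codes.t()).argmax(dim=-1)                          # nearest code (cosine)
        z_q = codes[idx]                                              # quantized features
        return z_q.permute(0, 3, 1, 2), idx
```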

function2-llx commented 1 year ago

@TomoshibiAkira @crowsonkb Thanks for sharing your insights. I am working on this at the moment and will post an update once a concrete conclusion is reached.

@fnzhan Congratulations on your article being accepted to CVPR 2023! Would you kindly share your code and pre-trained weights? It would help us better understand and follow up on your work.

OrangeSodahub commented 1 year ago

Hi guys, I have tried training another VQModel (first stage) on my own dataset (I modified the encoder and decoder a little). However, during training the vector quantization loss rises and the KL loss rises as well. Any suggestions?