CompVis / taming-transformers

Taming Transformers for High-Resolution Image Synthesis
https://arxiv.org/abs/2012.09841
MIT License

Very confused by the discriminator loss #93

Closed xesdiny closed 3 years ago

xesdiny commented 3 years ago

When training the VQGAN pipeline on the FFHQ dataset, I checked that disc_loss uses hinge_d_loss (a function like vanilla_d_loss):

import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real, logits_fake):
    # Hinge loss: push real logits above +1 and fake logits below -1.
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss
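
For comparison, the vanilla_d_loss variant in the same module replaces the hinge with a softplus, as far as I can tell:

import torch
import torch.nn.functional as F

def vanilla_d_loss(logits_real, logits_fake):
    # Non-saturating (softplus) discriminator loss: same real/fake split,
    # but without the hinge clipping.
    d_loss = 0.5 * (
        torch.mean(F.softplus(-logits_real)) +
        torch.mean(F.softplus(logits_fake)))
    return d_loss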

But the metric in TensorBoard looks very strange!

I am confused about whether this discriminator loss is really being optimized for generator training.

The discriminator loss joins the process after the training step reaches 30K. For reference, the discriminator-loss metric from the start of training is also included in the picture shown above.

hyakuchiki commented 3 years ago

A lot of people seem to have the same problem with the discriminator not being trained properly: https://github.com/CompVis/taming-transformers/issues/73

Have you looked at the d_weight value on TensorBoard? If it is fluctuating at high values, that might be a problem. I suspect that if the disc_start parameter is higher, the reconstruction will settle first and d_weight will take a sensible value. The authors suggest training 3-5 epochs without the discriminator in the case of ImageNet, so that would mean disc_start should be several million steps? I guess the discriminator should only be switched on once the VQVAE is starting to produce alright results: https://github.com/CompVis/taming-transformers/issues/31

The default value for disc_start is 10000 in custom_vqgan.yaml, which seems way too low. I had the same problem, so I set disc_start to 50000 and disc_weight to 0.2, and I'm getting somewhat better results (although I'm worried that disc_weight is a bit too low now?).
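
(As far as I can tell from the loss code, disc_start is compared against the global training step, not epochs, so it has to be scaled with the dataset and batch size; roughly:)

def adopt_weight(weight, global_step, threshold=0, value=0.):
    # Returns `value` (0 by default) until global_step reaches `threshold`
    # (i.e. disc_start), so the discriminator term is switched off before then.
    if global_step < threshold:
        weight = value
    return weight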

xesdiny commented 3 years ago

Emm, yeah! I understand: you mean that the discriminator is ineffective before the generator reaches a decent baseline, so the point at which the discriminator enters training should be delayed. The d_weight fraction is used as the weight coefficient of the discriminator term in the total loss. It is the ratio of the 2-norms of the gradients of rec_loss and g_loss with respect to the last layer of the model:

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        # Ratio of gradient norms at the decoder's last layer:
        # reconstruction (nll_loss) gradients vs. adversarial (g_loss) gradients.
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight
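
If I read vqperceptual.py correctly, this d_weight then scales the adversarial term of the generator objective, roughly like the sketch below (paraphrased, not the exact repo line):

def generator_loss(nll_loss, g_loss, codebook_loss, d_weight, disc_factor, codebook_weight=1.0):
    # Sketch of how the pieces combine in VQLPIPSWithDiscriminator:
    # disc_factor is 0 before global_step reaches disc_start, so the
    # adversarial term d_weight * g_loss only kicks in after that point.
    return nll_loss + d_weight * disc_factor * g_loss + codebook_weight * codebook_loss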

The d_weight value in your TensorBoard is approaching zero. I think this value should stay stable at around 1 to guide the generator (but in fact, when the value floated around 1, disc_loss did not decrease). Maybe I didn't understand the meaning behind d_weight correctly. Emm... I will adopt your suggestions for this pipeline. Thx~

fortunechen commented 3 years ago

Hi, how are your results now? Could you please share what you learned from tuning disc_start and disc_weight?

Thx

MaxyLee commented 2 years ago

Succeeded in getting a good result on the CUB dataset by setting disc_start=50,000 and disc_weight=0.2. Original images: [image] Reconstructed images: [image]

PanXiebit commented 2 years ago

@MaxyLee congratulations! Could you share more details of your settings? How many examples are in your CUB dataset, and how many steps are in one epoch? Specifically, after how many epochs do you start the discriminator?

MaxyLee commented 2 years ago


Here is my config:

model:
  base_learning_rate: 4.5e-6
  target: taming.models.vqgan.VQModel
  params:
    embed_dim: 256
    n_embed: 1024
    ddconfig:
      double_z: False
      z_channels: 256
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,1,2,2,4]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16]
      dropout: 0.0

    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: False
        disc_in_channels: 3
        disc_start: 50000
        disc_weight: 0.2
        codebook_weight: 1.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 5
    num_workers: 8
    train:
      target: taming.data.custom.CustomTrain
      params:
        training_images_list_file: /data/share/data/birds/CUB_200_2011/cub_train.txt
        size: 256
    validation:
      target: taming.data.custom.CustomTest
      params:
        test_images_list_file: /data/share/data/birds/CUB_200_2011/cub_test.txt
        size: 256

I trained this model on the CUB train split (8,855 images) using 4 GPUs, with approximately 400 steps per epoch. The discriminator therefore started after more than 100 epochs. Hope it helps.
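
A rough check of those numbers (assuming the effective batch size is batch_size × number of GPUs):

# Back-of-the-envelope check of the numbers above (not repo code).
images = 8855                 # CUB train split size
batch_size = 5                # per-GPU batch size from the config
num_gpus = 4
steps_per_epoch = images / (batch_size * num_gpus)         # ~443 steps/epoch
epochs_before_disc = 50000 / steps_per_epoch               # disc_start=50000 -> ~113 epochs
print(round(steps_per_epoch), round(epochs_before_disc))   # 443 113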

PanXiebit commented 2 years ago

@MaxyLee thank u very much!!!

PanXiebit commented 2 years ago

Hi @MaxyLee, I have trained the VQGAN with your settings on my own dataset; the discriminator starts at about 100 epochs, and disc_weight is 0.2. However, I still face the problem: the generated quality was alright, but after the discriminator started it became worse. These are my training curves:

[training curves]

In fact, the generated images are alright without the discriminator. In your training process, did your generated images become much better after GAN training?

MaxyLee commented 2 years ago


Yes, my model performed much better once the discriminator loss was introduced. As shown in the figure, my model could not generate fine-grained images without the discriminator: [image] Maybe you can try training the generator longer before adding the discriminator loss, and select the best checkpoint. Below are my training curves:

[training curves]
PanXiebit commented 2 years ago

@MaxyLee thank you for your patience and kindness! I will try more experiments.

kaihe commented 2 years ago

I think that for successful discriminator training, logits_fake should be negative and logits_real should be positive. But I noticed that in the above training curves, logits_fake and logits_real always look the same. Does that mean the discriminator has failed and just outputs the same value regardless of the input image? @MaxyLee, would you also share your training curves of the logits?
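
One quick way to check this (just a diagnostic sketch, not something from the repo) would be to log how often the discriminator actually separates real from fake under the hinge convention:

import torch

def disc_sign_accuracy(logits_real, logits_fake):
    # Fraction of patches the discriminator "gets right": real logits should be
    # positive and fake logits negative. ~0.5 means it cannot tell them apart.
    acc_real = (logits_real > 0).float().mean()
    acc_fake = (logits_fake < 0).float().mean()
    return 0.5 * (acc_real + acc_fake)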

MaxyLee commented 2 years ago


These are my training curves: [W&B charts]

kaihe commented 2 years ago


Thanks very much, that confirms my suspicion: a good discriminator is enough for sharp images; there is no need for a GAN equilibrium.

ThisisBillhe commented 2 months ago


Hi, how do you solve the problem of logits_real and logits_fake being almost the same?