A lot of people seem to have the same problem with the discriminator not being trained properly. https://github.com/CompVis/taming-transformers/issues/73 Have you looked at the d_weight value on TensorBoard? If it is fluctuating at high values, that might be a problem. I suspect that if the disc_start parameter is higher, the reconstruction will settle first and d_weight will take a sensible value. The authors suggest training 3-5 epochs without the discriminator in the case of ImageNet, so that would mean disc_start should be several million steps? I guess the discriminator should only be enabled once the VQVAE is starting to produce decent results. https://github.com/CompVis/taming-transformers/issues/31 The default value for disc_start is 10000 in custom_vqgan.yaml, which seems way too low. I had the same problem, so I set disc_start to 50000 and disc_weight to 0.2 and I'm getting somewhat better results (although I'm worried that disc_weight is a bit too low now?).
Emm, yeah!
If I understand you correctly, the discriminator is useless before the generator reaches a decent baseline, so the point at which the discriminator enters training should be delayed.
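For context, this delay is exactly what the disc_start parameter controls. A minimal sketch of the gating helper, paraphrased from taming/modules/losses/vqperceptual.py (check the repo for the exact code):

def adopt_weight(weight, global_step, threshold=0, value=0.0):
    # Force the adversarial weight to `value` (0 by default) until
    # global_step reaches `threshold` (i.e. disc_start).
    if global_step < threshold:
        weight = value
    return weight

# Example: with disc_start=50000, all GAN terms are inactive at step 30000.
disc_factor = adopt_weight(1.0, global_step=30000, threshold=50000)  # -> 0.0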
The d_weight fraction is used as the weight coefficient of the discriminator term in total_loss. It is computed as the ratio of the 2-norms of the gradients of rec_loss (nll_loss in the code) and g_loss with respect to the parameters of the model's last layer:
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
    # Gradients of the reconstruction (nll) and adversarial (g) losses
    # with respect to the weights of the decoder's last layer.
    if last_layer is not None:
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    else:
        nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

    # The ratio of gradient norms balances the two losses; clamp it and
    # detach so the weight itself receives no gradient.
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
    d_weight = d_weight * self.discriminator_weight
    return d_weight
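For context, here is roughly how this weight then enters the generator objective in VQLPIPSWithDiscriminator (a paraphrase, not the verbatim repo code; disc_factor comes from the adopt_weight gate sketched above):

def generator_loss(nll_loss, g_loss, codebook_loss,
                   d_weight, disc_factor, codebook_weight=1.0):
    # disc_factor is 0 before disc_start, so early training is driven
    # purely by the reconstruction/perceptual and codebook losses.
    return (nll_loss + d_weight * disc_factor * g_loss
            + codebook_weight * codebook_loss)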
The d_weight value in your TensorBoard is approaching zero. I think this value should stay stable at about 1 to guide the generator. (In practice, though, when the value was floating around 1, disc_loss did not decrease.) Maybe I didn't understand the meaning behind d_weight correctly.
Emm... I will adopt your suggestions in this pipeline.
Thx~
Hi, how are your results now? Could you please share what you learned from tuning disc_start and disc_weight?
Thx
Succeeded in getting a good result on the CUB dataset by setting disc_start=50000 and disc_weight=0.2:
Original images:
Reconstructed images:
@MaxyLee congratulations! Could you share more details of your setup? How many examples are in your CUB dataset, and how many steps are in one epoch? At exactly how many epochs did you start the discriminator?
Here is my config:
model:
  base_learning_rate: 4.5e-6
  target: taming.models.vqgan.VQModel
  params:
    embed_dim: 256
    n_embed: 1024
    ddconfig:
      double_z: False
      z_channels: 256
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,1,2,2,4 ]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16]
      dropout: 0.0
    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: False
        disc_in_channels: 3
        disc_start: 50000
        disc_weight: 0.2
        codebook_weight: 1.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 5
    num_workers: 8
    train:
      target: taming.data.custom.CustomTrain
      params:
        training_images_list_file: /data/share/data/birds/CUB_200_2011/cub_train.txt
        size: 256
    validation:
      target: taming.data.custom.CustomTest
      params:
        test_images_list_file: /data/share/data/birds/CUB_200_2011/cub_test.txt
        size: 256
I trained this model on the CUB train split (8,855 images) using 4 GPUs, with approximately 400 steps per epoch. The discriminator therefore started after more than 100 epochs (50000 / ~400 ≈ 125 epochs). Hope it helps; see the sanity-check sketch below.
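For anyone reproducing this, a quick back-of-the-envelope check (a hypothetical helper, not part of the repo) for converting a desired discriminator-free warm-up into a disc_start step count:

def disc_start_for(num_images, batch_size, num_gpus, warmup_epochs):
    # Under multi-GPU training, each optimizer step consumes
    # batch_size * num_gpus images.
    steps_per_epoch = num_images // (batch_size * num_gpus)
    return steps_per_epoch * warmup_epochs

# MaxyLee's setup: 8,855 images, batch_size=5, 4 GPUs -> ~442 steps/epoch,
# so disc_start=50000 corresponds to roughly 113 warm-up epochs.
print(disc_start_for(8855, 5, 4, 113))  # ~49946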
@MaxyLee thank u very much!!!
Hi @MaxyLee, I have trained the VQGAN with your settings on my own dataset; the discriminator starts at about 100 epochs, and disc_weight is 0.2. However, I still face the problem: the generated quality was alright, but after the discriminator started it became worse. This is my training curve.
In fact, the generated images are alright without the discriminator. In your training process, did your generated images become much better after GAN training?
Yes, my model performed much better once the discriminator loss was introduced. As shown in the figure, my model could not generate fine-grained images without the discriminator. Maybe you can try training the generator longer before adding the discriminator loss, and then select the best checkpoint. Below are my training curves:
@MaxyLee thank you for your patience and kindness! I will try more experiments.
I think that for successful discriminator training, logits_fake should be negative and logits_real should be positive. But I noticed that in the above training curves, logits_fake and logits_real always look the same. Does that mean the discriminator has failed and just outputs the same value regardless of the input image? @MaxyLee, would you also share your training curves of the logits?
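For reference, that sign convention comes from the hinge loss used by default in taming-transformers (paraphrased from taming/modules/losses/vqperceptual.py):

import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real, logits_fake):
    # Minimized when logits_real > +1 and logits_fake < -1, which is why
    # healthy curves show real logits positive and fake logits negative.
    # If both logits collapse to the same value, the discriminator is not
    # separating real from reconstructed images at all.
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)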
These are my training curves:
Thanks very much, that confirms my suspicion: a good discriminator is enough for sharp images; there is no need for a GAN equilibrium.
Hi, how do you solve the problem of logits_real and logits_fake being almost the same?
When training the VQGAN pipeline on the FFHQ dataset, I checked disc_loss, which uses the vanilla_d_loss function, but the metric in TensorBoard looks very strange. I am confused about whether this discriminator loss really helps optimize the generator.
The discriminator loss joins the process after the training step reaches 30K. By the way, the discriminator-loss metric from the start of training has been added to the picture shown above.
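For reference, vanilla_d_loss in taming-transformers (paraphrased from taming/modules/losses/vqperceptual.py) is the standard non-saturating GAN discriminator loss:

import torch
import torch.nn.functional as F

def vanilla_d_loss(logits_real, logits_fake):
    # softplus(-x) == -log(sigmoid(x)), so this is binary cross-entropy
    # with real labeled 1 and fake labeled 0. When the discriminator
    # cannot tell real from fake (logits_real == logits_fake == 0), the
    # loss sits at log(2) ~= 0.693, so a curve flat around 0.69 means the
    # discriminator is not separating the two distributions.
    return 0.5 * (
        torch.mean(F.softplus(-logits_real)) +
        torch.mean(F.softplus(logits_fake)))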