hanzhanggit / StackGAN


StageI Training Failure #32

Open Lotayou opened 6 years ago

Lotayou commented 6 years ago

I'm currently trying to train the network using the pre-trained text embeddings for the birds and flowers datasets. I'm using TensorFlow 1.3, so I made some necessary adaptations. But the results are frustrating:

[image: Stage-I training samples]

It seems the generator output is a complete mess. I also checked some loss values, and it's really weird that some of them never change:

[image: logged loss values]

I'm not sure what the cause of the problem could be. Is it possible that some TensorFlow core functions behave differently in r1.3 (as opposed to r0.12)?

Has anyone successfully reproduced the results with a newer version of TensorFlow? Could you send me a message at geekyang123@gmail.com so I can check some implementation issues with you? Much appreciated, guys!

Ereebay commented 6 years ago

I ran into a similar problem before; I modified the conditioning augmentation function, and it worked.

huanchen813 commented 6 years ago

@Ereebay I've met the same problem. Could you tell me how to modify the conditioning augmentation function?

NepTuNew commented 6 years ago

@Ereebay @chenhuangsong I've met the same problem and can't find a solution. Could you help me, or send a message to anderson08121995@gmail.com? Thanks a lot.

Tristan-YF commented 5 years ago

@Ereebay I've met the same problem. Could you please tell me how to modify the conditioning augmentation function?

Ereebay commented 5 years ago

It's not the conditioning augmentation, actually. I modified the code to drop Pretty Tensor, and my guess for why this happened is that I didn't save the moving-mean parameters of the batch normalization.
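To illustrate why stale moving statistics wreck inference, here is a minimal sketch (assuming TF 1.x and tf.contrib.layers.batch_norm; not the repository's code). With is_training=False, batch norm normalizes with the stored moving_mean/moving_variance; if those were never updated during training (no UPDATE_OPS dependency) or never saved in the checkpoint, they keep their initial values (0 and 1), so the output is effectively un-normalized:

import tensorflow as tf

# Activations deliberately far from zero mean / unit variance.
x = tf.random_normal([8, 16]) * 5.0 + 3.0

# Inference mode: normalizes with the stored moving statistics.
y = tf.contrib.layers.batch_norm(x, decay=0.9, is_training=False, scope='bn')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # moving_mean = 0 and moving_variance = 1 were never updated or restored,
    # so the "normalized" output is still roughly mean 3 / std 5.
    out = sess.run(y)
    print(out.mean(), out.std())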


NepTuNew commented 5 years ago

Thanks for your reply! I figured out the problem: I implemented the code from scratch in TensorFlow, so the problem I met was, as Ereebay said, with the moving-mean parameters of the batch normalization. I can confirm that the problem is in the batch normalization section.

If you just use the code from the StackGAN repository directly, I don't think you will run into this problem!

akhilvasvani commented 5 years ago

I tried to implement this as well. I used:

# Thin wrapper around tf.contrib.layers.batch_norm (note: `in_dim` is unused
# in this body).
def conv_batch_normalization(inputs, name, epsilon=1e-5, in_dim=None,
                             is_training=True, activation_fn=None, reuse=None):
    return tf.contrib.layers.batch_norm(inputs, decay=0.9, center=True, scale=True, epsilon=epsilon,
                                        activation_fn=activation_fn,
                                        param_initializers={'beta': tf.constant_initializer(0.),
                                                            'gamma': tf.random_normal_initializer(1., 0.02)},
                                        reuse=reuse, is_training=is_training, scope=name)

But this did not seem to solve the problem; the generator still produces weird results. Did you guys do something similar to this?

Ereebay commented 5 years ago

You need to update the moving means before the optimizer step runs; make each minimize op depend on that network's UPDATE_OPS:

def prepare_trainer(self, generator_loss, discriminator_loss):
    '''Helper function for init_opt'''
    all_vars = tf.trainable_variables()

    g_vars = [var for var in all_vars if
              var.name.startswith('g_')]
    d_vars = [var for var in all_vars if
              var.name.startswith('d_')]

    # Batch-norm moving-mean/variance update ops, split per network.
    update_ops_D = [op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if op.name.startswith('d_')]
    update_ops_G = [op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if op.name.startswith('g_')]

    # control_dependencies forces each train op to run its network's
    # moving-statistics updates first.
    with tf.control_dependencies(update_ops_G):
        generator_opt = tf.train.AdamOptimizer(self.generator_lr,
                                               beta1=0.5)
        self.generator_trainer = generator_opt.minimize(generator_loss, var_list=g_vars)
    with tf.control_dependencies(update_ops_D):
        discriminator_opt = tf.train.AdamOptimizer(self.discriminator_lr,
                                                   beta1=0.5)
        self.discriminator_trainer = discriminator_opt.minimize(discriminator_loss, var_list=d_vars)
    self.log_vars.append(("g_learning_rate", self.generator_lr))
    self.log_vars.append(("d_learning_rate", self.discriminator_lr))

You also need to save the global variables, or at least the trainable variables plus the moving mean and variance of the batch normalization. You will use them to restore the Stage-I model to generate samples for Stage II.
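A minimal sketch of that saving step (assuming TF 1.x; the checkpoint paths and step handling are illustrative, not the repository's exact code):

import tensorflow as tf

# Option 1: save everything, which includes the moving statistics.
save_vars = tf.global_variables()

# Option 2: build the list manually -- trainable variables plus the
# batch-norm moving mean/variance.
bn_stats = [v for v in tf.global_variables()
            if 'moving_mean' in v.name or 'moving_variance' in v.name]
save_vars = tf.trainable_variables() + bn_stats

saver = tf.train.Saver(save_vars)
# During Stage-I training:
#   saver.save(sess, 'ckpt/stage1/model.ckpt', global_step=step)
# When training Stage II, restore the Stage-I generator from that checkpoint:
#   saver.restore(sess, tf.train.latest_checkpoint('ckpt/stage1'))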

akhilvasvani commented 5 years ago

Thanks @Ereebay. That worked!

According to the TensorFlow documentation, tf.contrib.layers.batch_norm has a trainable option that defaults to True and adds the beta/gamma variables to the trainable-variables collection; the moving mean and moving variance are created as non-trainable variables, so they end up in the global-variables collection instead.
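A quick way to check that (hedged sketch, TF 1.x):

bn_stats = [v.name for v in tf.global_variables() if 'moving_' in v.name]
trainable_stats = [v.name for v in tf.trainable_variables() if 'moving_' in v.name]
print(len(bn_stats))         # > 0: the moving statistics are global variables
print(len(trainable_stats))  # 0: they are not trainable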

I'm also assuming I need to update my prepare_trainer function in the Stage-II trainer as well, following what you posted above?

Ereebay commented 5 years ago

I'm not sure whether it will add moving_mean to the trainable-variables list; I recommend you build a variable list manually. And according to the docs, you do need that code to update the moving mean. Anyway, this code has been outdated for a long time; I recommend implementing it from scratch with the latest version of TF or PyTorch.

akhilvasvani commented 5 years ago

I checked, and they were added to the global variables.
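For reference, a snippet like this (a hypothetical sketch, TF 1.x) prints the name/shape pairs listed below:

import tensorflow as tf

for v in sorted(tf.global_variables(), key=lambda v: v.name):
    print((v.name.replace(':0', ''), v.get_shape().as_list()))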

('d_net/d_embedd/fc/biases', [128])
('d_net/d_embedd/fc/weights', [1024, 128])
('d_net/d_n1.0/batch_norm/beta', [128])
('d_net/d_n1.0/batch_norm/gamma', [128])
('d_net/d_n1.0/batch_norm/moving_mean', [128])
('d_net/d_n1.0/batch_norm/moving_variance', [128])
('d_net/d_n1.0/batch_norm2/beta', [256])
('d_net/d_n1.0/batch_norm2/gamma', [256])
('d_net/d_n1.0/batch_norm2/moving_mean', [256])
('d_net/d_n1.0/batch_norm2/moving_variance', [256])
('d_net/d_n1.0/batch_norm3/beta', [512])
('d_net/d_n1.0/batch_norm3/gamma', [512])
('d_net/d_n1.0/batch_norm3/moving_mean', [512])
('d_net/d_n1.0/batch_norm3/moving_variance', [512])
('d_net/d_n1.0/conv2d/weights', [4, 4, 3, 64])
('d_net/d_n1.0/conv2d2/weights', [4, 4, 64, 128])
('d_net/d_n1.0/conv2d3/weights', [4, 4, 128, 256])
('d_net/d_n1.0/conv2d4/weights', [4, 4, 256, 512])
('d_net/d_n1.1/batch_norm/beta', [128])
('d_net/d_n1.1/batch_norm/gamma', [128])
('d_net/d_n1.1/batch_norm/moving_mean', [128])
('d_net/d_n1.1/batch_norm/moving_variance', [128])
('d_net/d_n1.1/batch_norm2/beta', [128])
('d_net/d_n1.1/batch_norm2/gamma', [128])
('d_net/d_n1.1/batch_norm2/moving_mean', [128])
('d_net/d_n1.1/batch_norm2/moving_variance', [128])
('d_net/d_n1.1/batch_norm3/beta', [512])
('d_net/d_n1.1/batch_norm3/gamma', [512])
('d_net/d_n1.1/batch_norm3/moving_mean', [512])
('d_net/d_n1.1/batch_norm3/moving_variance', [512])
('d_net/d_n1.1/conv2d/weights', [1, 1, 512, 128])
('d_net/d_n1.1/conv2d2/weights', [3, 3, 128, 128])
('d_net/d_n1.1/conv2d3/weights', [3, 3, 128, 512])
('d_net/d_net/d_embedd/fc/biases/Adam', [128])
('d_net/d_net/d_embedd/fc/biases/Adam_1', [128])
('d_net/d_net/d_embedd/fc/weights/Adam', [1024, 128])
('d_net/d_net/d_embedd/fc/weights/Adam_1', [1024, 128])
('d_net/d_net/d_n1.0/batch_norm/beta/Adam', [128])
('d_net/d_net/d_n1.0/batch_norm/beta/Adam_1', [128])
('d_net/d_net/d_n1.0/batch_norm/gamma/Adam', [128])
('d_net/d_net/d_n1.0/batch_norm/gamma/Adam_1', [128])
('d_net/d_net/d_n1.0/batch_norm2/beta/Adam', [256])
('d_net/d_net/d_n1.0/batch_norm2/beta/Adam_1', [256])
('d_net/d_net/d_n1.0/batch_norm2/gamma/Adam', [256])
('d_net/d_net/d_n1.0/batch_norm2/gamma/Adam_1', [256])
('d_net/d_net/d_n1.0/batch_norm3/beta/Adam', [512])
('d_net/d_net/d_n1.0/batch_norm3/beta/Adam_1', [512])
('d_net/d_net/d_n1.0/batch_norm3/gamma/Adam', [512])
('d_net/d_net/d_n1.0/batch_norm3/gamma/Adam_1', [512])
('d_net/d_net/d_n1.0/conv2d/weights/Adam', [4, 4, 3, 64])
('d_net/d_net/d_n1.0/conv2d/weights/Adam_1', [4, 4, 3, 64])
('d_net/d_net/d_n1.0/conv2d2/weights/Adam', [4, 4, 64, 128])
('d_net/d_net/d_n1.0/conv2d2/weights/Adam_1', [4, 4, 64, 128])
('d_net/d_net/d_n1.0/conv2d3/weights/Adam', [4, 4, 128, 256])
('d_net/d_net/d_n1.0/conv2d3/weights/Adam_1', [4, 4, 128, 256])
('d_net/d_net/d_n1.0/conv2d4/weights/Adam', [4, 4, 256, 512])
('d_net/d_net/d_n1.0/conv2d4/weights/Adam_1', [4, 4, 256, 512])
('d_net/d_net/d_n1.1/batch_norm/beta/Adam', [128])
('d_net/d_net/d_n1.1/batch_norm/beta/Adam_1', [128])
('d_net/d_net/d_n1.1/batch_norm/gamma/Adam', [128])
('d_net/d_net/d_n1.1/batch_norm/gamma/Adam_1', [128])
('d_net/d_net/d_n1.1/batch_norm2/beta/Adam', [128])
('d_net/d_net/d_n1.1/batch_norm2/beta/Adam_1', [128])
('d_net/d_net/d_n1.1/batch_norm2/gamma/Adam', [128])
('d_net/d_net/d_n1.1/batch_norm2/gamma/Adam_1', [128])
('d_net/d_net/d_n1.1/batch_norm3/beta/Adam', [512])
('d_net/d_net/d_n1.1/batch_norm3/beta/Adam_1', [512])
('d_net/d_net/d_n1.1/batch_norm3/gamma/Adam', [512])
('d_net/d_net/d_n1.1/batch_norm3/gamma/Adam_1', [512])
('d_net/d_net/d_n1.1/conv2d/weights/Adam', [1, 1, 512, 128])
('d_net/d_net/d_n1.1/conv2d/weights/Adam_1', [1, 1, 512, 128])
('d_net/d_net/d_n1.1/conv2d2/weights/Adam', [3, 3, 128, 128])
('d_net/d_net/d_n1.1/conv2d2/weights/Adam_1', [3, 3, 128, 128])
('d_net/d_net/d_n1.1/conv2d3/weights/Adam', [3, 3, 128, 512])
('d_net/d_net/d_n1.1/conv2d3/weights/Adam_1', [3, 3, 128, 512])
('d_net/d_net/d_template/batch_norm/beta/Adam', [512])
('d_net/d_net/d_template/batch_norm/beta/Adam_1', [512])
('d_net/d_net/d_template/batch_norm/gamma/Adam', [512])
('d_net/d_net/d_template/batch_norm/gamma/Adam_1', [512])
('d_net/d_net/d_template/conv2d/weights/Adam', [1, 1, 640, 512])
('d_net/d_net/d_template/conv2d/weights/Adam_1', [1, 1, 640, 512])
('d_net/d_net/d_template/conv2d2/weights/Adam', [4, 4, 512, 1])
('d_net/d_net/d_template/conv2d2/weights/Adam_1', [4, 4, 512, 1])
('d_net/d_template/batch_norm/beta', [512])
('d_net/d_template/batch_norm/gamma', [512])
('d_net/d_template/batch_norm/moving_mean', [512])
('d_net/d_template/batch_norm/moving_variance', [512])
('d_net/d_template/conv2d/weights', [1, 1, 640, 512])
('d_net/d_template/conv2d2/weights', [4, 4, 512, 1])
('d_net/g_net/g_OT/batch_norm/beta/Adam', [256])
('d_net/g_net/g_OT/batch_norm/beta/Adam_1', [256])
('d_net/g_net/g_OT/batch_norm/gamma/Adam', [256])
('d_net/g_net/g_OT/batch_norm/gamma/Adam_1', [256])
('d_net/g_net/g_OT/batch_norm2/beta/Adam', [128])
('d_net/g_net/g_OT/batch_norm2/beta/Adam_1', [128])
('d_net/g_net/g_OT/batch_norm2/gamma/Adam', [128])
('d_net/g_net/g_OT/batch_norm2/gamma/Adam_1', [128])
('d_net/g_net/g_OT/conv2d/weights/Adam', [3, 3, 512, 256])
('d_net/g_net/g_OT/conv2d/weights/Adam_1', [3, 3, 512, 256])
('d_net/g_net/g_OT/conv2d2/weights/Adam', [3, 3, 256, 128])
('d_net/g_net/g_OT/conv2d2/weights/Adam_1', [3, 3, 256, 128])
('d_net/g_net/g_OT/conv2d3/weights/Adam', [3, 3, 128, 3])
('d_net/g_net/g_OT/conv2d3/weights/Adam_1', [3, 3, 128, 3])
('d_net/g_net/g_n1.0/batch_norm/beta/Adam', [16384])
('d_net/g_net/g_n1.0/batch_norm/beta/Adam_1', [16384])
('d_net/g_net/g_n1.0/batch_norm/gamma/Adam', [16384])
('d_net/g_net/g_n1.0/batch_norm/gamma/Adam_1', [16384])
('d_net/g_net/g_n1.0/fc/biases/Adam', [16384])
('d_net/g_net/g_n1.0/fc/biases/Adam_1', [16384])
('d_net/g_net/g_n1.0/fc/weights/Adam', [228, 16384])
('d_net/g_net/g_n1.0/fc/weights/Adam_1', [228, 16384])
('d_net/g_net/g_n1.1/batch_norm_1/beta/Adam', [256])
('d_net/g_net/g_n1.1/batch_norm_1/beta/Adam_1', [256])
('d_net/g_net/g_n1.1/batch_norm_1/gamma/Adam', [256])
('d_net/g_net/g_n1.1/batch_norm_1/gamma/Adam_1', [256])
('d_net/g_net/g_n1.1/batch_norm_2/beta/Adam', [256])
('d_net/g_net/g_n1.1/batch_norm_2/beta/Adam_1', [256])
('d_net/g_net/g_n1.1/batch_norm_2/gamma/Adam', [256])
('d_net/g_net/g_n1.1/batch_norm_2/gamma/Adam_1', [256])
('d_net/g_net/g_n1.1/batch_norm_3/beta/Adam', [1024])
('d_net/g_net/g_n1.1/batch_norm_3/beta/Adam_1', [1024])
('d_net/g_net/g_n1.1/batch_norm_3/gamma/Adam', [1024])
('d_net/g_net/g_n1.1/batch_norm_3/gamma/Adam_1', [1024])
('d_net/g_net/g_n1.1/conv2d/weights/Adam', [1, 1, 1024, 256])
('d_net/g_net/g_n1.1/conv2d/weights/Adam_1', [1, 1, 1024, 256])
('d_net/g_net/g_n1.1/conv2d2/weights/Adam', [3, 3, 256, 256])
('d_net/g_net/g_n1.1/conv2d2/weights/Adam_1', [3, 3, 256, 256])
('d_net/g_net/g_n1.1/conv2d3/weights/Adam', [3, 3, 256, 1024])
('d_net/g_net/g_n1.1/conv2d3/weights/Adam_1', [3, 3, 256, 1024])
('d_net/g_net/g_n2.0/batch_norm/beta/Adam', [512])
('d_net/g_net/g_n2.0/batch_norm/beta/Adam_1', [512])
('d_net/g_net/g_n2.0/batch_norm/gamma/Adam', [512])
('d_net/g_net/g_n2.0/batch_norm/gamma/Adam_1', [512])
('d_net/g_net/g_n2.0/conv2d/weights/Adam', [3, 3, 1024, 512])
('d_net/g_net/g_n2.0/conv2d/weights/Adam_1', [3, 3, 1024, 512])
('d_net/g_net/g_n2.1/batch_norm/beta/Adam', [128])
('d_net/g_net/g_n2.1/batch_norm/beta/Adam_1', [128])
('d_net/g_net/g_n2.1/batch_norm/gamma/Adam', [128])
('d_net/g_net/g_n2.1/batch_norm/gamma/Adam_1', [128])
('d_net/g_net/g_n2.1/batch_norm2/beta/Adam', [128])
('d_net/g_net/g_n2.1/batch_norm2/beta/Adam_1', [128])
('d_net/g_net/g_n2.1/batch_norm2/gamma/Adam', [128])
('d_net/g_net/g_n2.1/batch_norm2/gamma/Adam_1', [128])
('d_net/g_net/g_n2.1/batch_norm3/beta/Adam', [512])
('d_net/g_net/g_n2.1/batch_norm3/beta/Adam_1', [512])
('d_net/g_net/g_n2.1/batch_norm3/gamma/Adam', [512])
('d_net/g_net/g_n2.1/batch_norm3/gamma/Adam_1', [512])
('d_net/g_net/g_n2.1/conv2d/weights/Adam', [1, 1, 512, 128])
('d_net/g_net/g_n2.1/conv2d/weights/Adam_1', [1, 1, 512, 128])
('d_net/g_net/g_n2.1/conv2d2/weights/Adam', [3, 3, 128, 128])
('d_net/g_net/g_n2.1/conv2d2/weights/Adam_1', [3, 3, 128, 128])
('d_net/g_net/g_n2.1/conv2d3/weights/Adam', [3, 3, 128, 512])
('d_net/g_net/g_n2.1/conv2d3/weights/Adam_1', [3, 3, 128, 512])
('d_net/g_net/gen_cond/fc/biases/Adam', [256])
('d_net/g_net/gen_cond/fc/biases/Adam_1', [256])
('d_net/g_net/gen_cond/fc/weights/Adam', [1024, 256])
('d_net/g_net/gen_cond/fc/weights/Adam_1', [1024, 256])
('g_net/g_OT/batch_norm/beta', [256])
('g_net/g_OT/batch_norm/gamma', [256])
('g_net/g_OT/batch_norm/moving_mean', [256])
('g_net/g_OT/batch_norm/moving_variance', [256])
('g_net/g_OT/batch_norm2/beta', [128])
('g_net/g_OT/batch_norm2/gamma', [128])
('g_net/g_OT/batch_norm2/moving_mean', [128])
('g_net/g_OT/batch_norm2/moving_variance', [128])
('g_net/g_OT/conv2d/weights', [3, 3, 512, 256])
('g_net/g_OT/conv2d2/weights', [3, 3, 256, 128])
('g_net/g_OT/conv2d3/weights', [3, 3, 128, 3])
('g_net/g_n1.0/batch_norm/beta', [16384])
('g_net/g_n1.0/batch_norm/gamma', [16384])
('g_net/g_n1.0/batch_norm/moving_mean', [16384])
('g_net/g_n1.0/batch_norm/moving_variance', [16384])
('g_net/g_n1.0/fc/biases', [16384])
('g_net/g_n1.0/fc/weights', [228, 16384])
('g_net/g_n1.1/batch_norm_1/beta', [256])
('g_net/g_n1.1/batch_norm_1/gamma', [256])
('g_net/g_n1.1/batch_norm_1/moving_mean', [256])
('g_net/g_n1.1/batch_norm_1/moving_variance', [256])
('g_net/g_n1.1/batch_norm_2/beta', [256])
('g_net/g_n1.1/batch_norm_2/gamma', [256])
('g_net/g_n1.1/batch_norm_2/moving_mean', [256])
('g_net/g_n1.1/batch_norm_2/moving_variance', [256])
('g_net/g_n1.1/batch_norm_3/beta', [1024])
('g_net/g_n1.1/batch_norm_3/gamma', [1024])
('g_net/g_n1.1/batch_norm_3/moving_mean', [1024])
('g_net/g_n1.1/batch_norm_3/moving_variance', [1024])
('g_net/g_n1.1/conv2d/weights', [1, 1, 1024, 256])
('g_net/g_n1.1/conv2d2/weights', [3, 3, 256, 256])
('g_net/g_n1.1/conv2d3/weights', [3, 3, 256, 1024])
('g_net/g_n2.0/batch_norm/beta', [512])
('g_net/g_n2.0/batch_norm/gamma', [512])
('g_net/g_n2.0/batch_norm/moving_mean', [512])
('g_net/g_n2.0/batch_norm/moving_variance', [512])
('g_net/g_n2.0/conv2d/weights', [3, 3, 1024, 512])
('g_net/g_n2.1/batch_norm/beta', [128])
('g_net/g_n2.1/batch_norm/gamma', [128])
('g_net/g_n2.1/batch_norm/moving_mean', [128])
('g_net/g_n2.1/batch_norm/moving_variance', [128])
('g_net/g_n2.1/batch_norm2/beta', [128])
('g_net/g_n2.1/batch_norm2/gamma', [128])
('g_net/g_n2.1/batch_norm2/moving_mean', [128])
('g_net/g_n2.1/batch_norm2/moving_variance', [128])
('g_net/g_n2.1/batch_norm3/beta', [512])
('g_net/g_n2.1/batch_norm3/gamma', [512])
('g_net/g_n2.1/batch_norm3/moving_mean', [512])
('g_net/g_n2.1/batch_norm3/moving_variance', [512])
('g_net/g_n2.1/conv2d/weights', [1, 1, 512, 128])
('g_net/g_n2.1/conv2d2/weights', [3, 3, 128, 128])
('g_net/g_n2.1/conv2d3/weights', [3, 3, 128, 512])
('g_net/gen_cond/fc/biases', [256])
('g_net/gen_cond/fc/weights', [1024, 256])

Appreciate the help. Thank you!