eriklindernoren / Keras-GAN

Keras implementations of Generative Adversarial Networks.
MIT License

wgan_gp #166

Open kristosh opened 5 years ago

kristosh commented 5 years ago

I would like to modify the critic of the network so that it works not only as a discriminator but also as an auxiliary classifier. In other words, I would like the last layer of the critic to output both the Wasserstein critic score and the class probabilities for a multi-class cross-entropy loss. I modified the code accordingly, so my `__init__` function now looks like this:

```python
def __init__(self):
    self.img_rows = 28
    self.img_cols = 28
    self.channels = 3
    self.img_shape = (self.img_rows, self.img_cols, self.channels)
    self.latent_dim = 106

    # Following parameter and optimizer set as recommended in the paper
    self.n_critic = 5
    optimizer = RMSprop(lr=0.00001)

    # Build the generator and critic
    self.generator = build_generator_iwGANs()
    self.critic = build_critic_iwGANs()

    #-------------------------------
    # Construct Computational Graph
    #       for the Critic
    #-------------------------------

    # Image input (real sample)
    real_img = Input(shape=self.img_shape)

    # Noise input
    z_disc = Input(shape=(self.latent_dim,))

    # Generate image based on noise (fake sample)
    fake_img = self.generator(z_disc)

    # Discriminator determines validity of the real and fake images
    fake, aux1 = self.critic(fake_img)
    valid, aux2 = self.critic(real_img)

    # Construct weighted average between real and fake images
    interpolated_img = RandomWeightedAverage()([real_img, fake_img])
    # Determine validity of weighted sample
    validity_interpolated, classified = self.critic(interpolated_img)

    # Use Python partial to provide loss function with additional
    # 'averaged_samples' argument
    partial_gp_loss = partial(self.gradient_penalty_loss,
                              averaged_samples=interpolated_img)
    partial_gp_loss.__name__ = 'gradient_penalty'  # Keras requires function names

    #pdb.set_trace()
    self.critic_model = Model(inputs=[real_img, z_disc],
                              outputs=[valid, fake, validity_interpolated, aux1])

    self.critic_model.compile(loss=[self.wasserstein_loss,
                                    self.wasserstein_loss,
                                    partial_gp_loss,
                                    'sparse_categorical_crossentropy'],
                              optimizer=optimizer,
                              loss_weights=[1, 1, 10, 10],
                              metrics=['accuracy'])

    # For the generator we freeze the critic's layers
    self.critic.trainable = False
    self.generator.trainable = True

    # Sampled noise for input to generator
    z_gen = Input(shape=(self.latent_dim,))
    # Generate images based on noise
    img = self.generator(z_gen)
    # Discriminator determines validity
    generated, classified_emotion = self.critic(img)
    # Defines generator model
    self.generator_model = Model(z_gen, generated)
    #g_loss = self.generator_model.train_on_batch(noise, [valid, batch_lbls])
    self.generator_model.compile(loss=self.wasserstein_loss, optimizer=optimizer)
```
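For reference, Keras pairs each entry of `outputs` with the loss and `loss_weights` entry at the same position and minimizes their weighted sum. Below is a minimal NumPy sketch of that total for the four critic outputs above. It assumes `self.wasserstein_loss` is `mean(y_true * y_pred)` with targets of -1 for real and +1 for fake samples, as in this repository's wgan_gp example, and treats the gradient-penalty value as already computed; `critic_total_loss` and the toy numbers are hypothetical.

```python
import numpy as np


def wasserstein_loss(y_true, y_pred):
    # Assumed to match the repo's wgan_gp helper: mean(y_true * y_pred)
    return float(np.mean(y_true * y_pred))


def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    # labels: integer class ids, shape (batch,); probs: softmax rows, shape (batch, n_classes)
    picked = probs[np.arange(len(labels)), labels]
    return float(-np.mean(np.log(np.clip(picked, eps, 1.0))))


def critic_total_loss(real_scores, fake_scores, gp_value, aux_probs, labels,
                      weights=(1, 1, 10, 10)):
    # Same order as outputs=[valid, fake, validity_interpolated, aux1]
    batch = len(labels)
    terms = [
        wasserstein_loss(-np.ones((batch, 1)), real_scores),  # target -1 for real samples
        wasserstein_loss(np.ones((batch, 1)), fake_scores),   # target +1 for fake samples
        gp_value,                                             # precomputed gradient penalty
        sparse_categorical_crossentropy(labels, aux_probs),   # auxiliary classification term
    ]
    total = sum(w * t for w, t in zip(weights, terms))
    return total, terms


real_scores = np.array([[2.0], [1.0]])
fake_scores = np.array([[-1.0], [0.0]])
aux_probs = np.array([[0.7, 0.1, 0.05, 0.05, 0.05, 0.05],
                      [0.1, 0.5, 0.1, 0.1, 0.1, 0.1]])
labels = np.array([0, 1])
total, terms = critic_total_loss(real_scores, fake_scores, 0.25, aux_probs, labels)
```

Since the Wasserstein terms can be negative, the total alone says little about classification quality; the per-output losses (and, with `metrics=['accuracy']`, per-output accuracies) that `train_on_batch` returns alongside the total are easier to read.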

The critic model:

```python
from keras.models import Sequential, Model
from keras.layers import (Input, Dense, Flatten, Dropout, Conv2D,
                          ZeroPadding2D, BatchNormalization, LeakyReLU)

img_shape = (28, 28, 3)  # same as self.img_shape in __init__


def build_critic_iwGANs():

    model = Sequential()

    model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=img_shape, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
    model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Flatten())

    img = Input(shape=img_shape)

    # Shared convolutional features for both heads
    features = model(img)

    # Linear critic score and 6-class softmax probabilities
    fake = Dense(1, activation='linear', name='generation')(features)
    aux = Dense(6, activation='softmax', name='auxiliary')(features)

    return Model(img, outputs=[fake, aux])
```
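The critic above shares one convolutional trunk between a linear `generation` head and a softmax `auxiliary` head. For 28x28 inputs the `Flatten` output has 4 * 4 * 128 = 2048 features (28 → 14 → 7, zero-padded to 8 → 4, then three stride-1 convolutions with 128 filters). A minimal NumPy sketch of the two heads on features of that size, with random hypothetical weights standing in for the trained `Dense` layers; `two_head_critic` is an illustrative name, not part of Keras:

```python
import numpy as np


def softmax(logits):
    # Subtract the row max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def two_head_critic(features, w_score, b_score, w_aux, b_aux):
    # Shared features feed both heads, mirroring the 'generation' and 'auxiliary' Dense layers
    score = features @ w_score + b_score       # linear head: unbounded critic score, (batch, 1)
    probs = softmax(features @ w_aux + b_aux)  # softmax head: class probabilities, (batch, n_classes)
    return score, probs


rng = np.random.default_rng(0)
n_features, n_classes, batch = 4 * 4 * 128, 6, 3  # Flatten size for 28x28 inputs with the layers above
features = rng.normal(size=(batch, n_features))
score, probs = two_head_critic(
    features,
    rng.normal(scale=0.01, size=(n_features, 1)), np.zeros(1),
    rng.normal(scale=0.01, size=(n_features, n_classes)), np.zeros(n_classes),
)
predicted_class = probs.argmax(axis=1)  # integer class ids, one per sample
```

The argmax of each softmax row is an integer class id, which is also the target format `'sparse_categorical_crossentropy'` expects (shape `(batch,)` or `(batch, 1)`); one-hot targets would need `'categorical_crossentropy'` instead.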

The critic is trained with:

```python
d_loss = self.critic_model.train_on_batch([imgs, noise], [valid, fake, dummy, batch_lbls])
```

Here `batch_lbls` are the ground-truth class labels of the real samples. The system works and produces very nice visual results, but the classification performance does not improve during training: the classifier does not seem to learn anything. Is there an obvious bug in my code?
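One thing worth checking in the graph above: the `'sparse_categorical_crossentropy'` loss is attached to `aux1`, the critic's prediction on `fake_img`, while `aux2` (the prediction on `real_img`) is never used, and `generator_model` is compiled with the Wasserstein loss only. So the class head is only trained on generated images, and nothing in the generator's objective ties those images to their labels. A common alternative is the AC-GAN recipe (Odena et al.), which is a swapped-in technique rather than what this code does: apply the cross-entropy to the predictions on real images with their true labels, and optionally to generated images with the labels used to condition them. A minimal NumPy sketch of that combined term, with hypothetical helper names:

```python
import numpy as np


def sparse_ce(probs, labels, eps=1e-7):
    # Mean negative log-probability assigned to each sample's true class id
    picked = probs[np.arange(len(labels)), labels]
    return float(-np.mean(np.log(np.clip(picked, eps, 1.0))))


def acgan_critic_class_loss(probs_real, labels_real, probs_fake=None, labels_fake=None):
    # Real images paired with their ground-truth labels (aux2 with batch_lbls in the code above)
    loss = sparse_ce(probs_real, labels_real)
    if probs_fake is not None:
        # Generated images paired with the labels concatenated into their noise (aux1)
        loss += sparse_ce(probs_fake, labels_fake)
    return loss


labels = np.array([2, 0])
probs_real = np.full((2, 6), 0.1)
probs_real[np.arange(2), labels] = 0.5  # each row sums to 1.0
loss_real_only = acgan_critic_class_loss(probs_real, labels)
loss_both = acgan_critic_class_loss(probs_real, labels, probs_real, labels)
```

In Keras terms this would roughly mean adding `aux2` (and, if wanted, keeping `aux1`) to the critic model's outputs with matching targets in `train_on_batch`, and giving `generator_model` a second, classification output, as the commented-out `g_loss` line seems to intend. This is a suggestion to try, not a verified fix.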

CUITCHENSIYU commented 5 years ago

I think you should add the labels.

kristosh commented 5 years ago

Where do you mean? During training, the input to the generator (`z_gen = Input(shape=(self.latent_dim,))`) is some noise concatenated with the labels, and I pass it as input to the critic. Then, when I train the critic, I do:

```python
d_loss = self.critic_model.train_on_batch([imgs, noise], [valid, fake, dummy, batch_lbls])
```

`imgs` are the target images, `noise` is the input to the generator, and `batch_lbls` are the corresponding labels of the target samples.
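The description above (noise concatenated with labels to fill the 106-dimensional `latent_dim`) is consistent with 100 noise dimensions plus a one-hot vector over the 6 classes of the `auxiliary` head. That 100 + 6 split is an assumption, not something stated in the issue. A minimal NumPy sketch with a hypothetical helper `conditioned_latent`:

```python
import numpy as np


def conditioned_latent(labels, noise_dim=100, n_classes=6, rng=None):
    # Assumed split: latent_dim = noise_dim + n_classes = 106
    rng = np.random.default_rng() if rng is None else rng
    noise = rng.normal(0.0, 1.0, size=(len(labels), noise_dim))
    one_hot = np.eye(n_classes)[labels]  # (batch, n_classes) one-hot rows
    return np.concatenate([noise, one_hot], axis=1)


batch_lbls = np.array([0, 3, 5, 1])
z = conditioned_latent(batch_lbls, rng=np.random.default_rng(42))
```

Keeping the integer `batch_lbls` for the sparse cross-entropy target and the one-hot copy only inside the latent avoids mixing the two label formats.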