eriklindernoren / Keras-GAN

Keras implementations of Generative Adversarial Networks.
MIT License

Adding conditional to WGAN-GP #53

Open ortix opened 6 years ago

ortix commented 6 years ago

I am extending the WGAN-GP implementation to be conditional by concatenating a label input to both the discriminator's input and the generator's noise input.
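
As a minimal sketch of that concatenation (in NumPy rather than Keras, with made-up sizes that are not taken from the actual model), the label is joined to the other input along the feature axis, so each network sees the label as extra input features:

```python
import numpy as np

# All sizes below are placeholder values chosen for illustration only.
batch_size, latent_size, label_size, input_size = 4, 8, 3, 5

rng = np.random.default_rng(0)
noise = rng.normal(size=(batch_size, latent_size))
label = rng.normal(size=(batch_size, label_size))
samples = rng.normal(size=(batch_size, input_size))

# Generator input: noise and label joined along the feature axis.
generator_input = np.concatenate([noise, label], axis=1)
# Discriminator input: a sample joined with the same label.
discriminator_input = np.concatenate([samples, label], axis=1)

print(generator_input.shape)      # (4, 11)
print(discriminator_input.shape)  # (4, 8)
```

In Keras the same step would typically be a `Concatenate` layer inside `build_generator` and `build_discriminator`, which is why both models are called with a list of two inputs.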

However, I am getting stuck in the final part where I build the combined model.

        # The generator takes noise and the target label (states) as input
        # and generates the corresponding samples of that label
        noise = Input(shape=(self.latent_size, ), name="noise")
        label = Input(shape=(self.label_size, ), name="labels")
        real_samples = Input(shape=(self.input_size,), name="real")

        self.discriminator = self.build_discriminator()
        self.generator = self.build_generator([noise, label])

        # First we train the discriminator
        self.generator.trainable = False
        fake_samples = self.generator([noise, label])

        fake = self.discriminator([fake_samples, label])
        valid = self.discriminator([real_samples, label])

        interpolated = Lambda(self.random_weighted_average)([real_samples, fake_samples])
        valid_interp = self.discriminator([interpolated, label])

        # The discriminator (critic) training model: scores real, fake, and
        # interpolated samples while the generator is frozen
        self.d_model = Model([real_samples, noise, label],
                             [valid, fake, valid_interp],
                             name="discriminator")
        # Time to train the generator
        self.discriminator.trainable = False
        self.generator.trainable = True

        noise_gen = Input(shape=(self.latent_size,), name="noise_gen")

        fake_samples = self.generator([noise_gen, label])
        valid = self.discriminator([fake_samples, label])

        self.g_model = Model([noise_gen, label], valid, name="generator")
        self.g_model.compile(loss=self.wasserstein_loss, optimizer=optimizer)

I don't think this is the right way to create the final model. How should I build the combined model so that it also includes the label? Should the noise input actually be the generated output of the generator? Any help would be appreciated.
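
The snippet calls `self.random_weighted_average` and `self.wasserstein_loss` without showing them. The following NumPy sketch is an assumption about what such helpers compute, based on the usual WGAN-GP formulation, and is not the code from this issue; the function signatures (including the `rng` argument) are hypothetical.

```python
import numpy as np

def random_weighted_average(real, fake, rng):
    """Return a random point on the line between each real/fake pair."""
    # One mixing weight per sample, broadcast over the feature axis.
    alpha = rng.uniform(size=(real.shape[0], 1))
    return alpha * real + (1.0 - alpha) * fake

def wasserstein_loss(y_true, y_pred):
    """Mean of target * critic score; targets are typically -1 and +1."""
    return float(np.mean(y_true * y_pred))

rng = np.random.default_rng(0)
real = np.ones((4, 5))   # stand-in for a batch of real samples
fake = np.zeros((4, 5))  # stand-in for a batch of generated samples
interpolated = random_weighted_average(real, fake, rng)
```

Note that only the samples are interpolated. In the snippet the label is passed to the discriminator unchanged alongside `interpolated`, so the real, fake, and interpolated samples are all scored under the same condition.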

ParthaEth commented 4 years ago

How do you plan on getting the right conditioning?