bstriner / keras-adversarial

Keras Generative Adversarial Networks
MIT License

Is there a way to incorporate contextual loss in the generator model? #10

Open kevinwu23 opened 7 years ago

kevinwu23 commented 7 years ago

How would you incorporate an input other than a random seed in the generator function? I am trying to recreate image completion with GANs for Keras as described in the following: https://bamos.github.io/2016/08/09/deep-completion/

Please let me know if there's any clarification needed. Thanks in advance for your time.

bstriner commented 7 years ago

All you should need is an extra input that is provided to both the generator and the discriminator. The generator's inputs are the latent sample (random seed) and the context, and its output is the generated sample. The discriminator's inputs are the context and the sample, and its output is the real/fake prediction.

Just make sure you keep track of the order of your inputs when you build the model. Please let me know if you run into any issues along the way.
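A minimal sketch of such a conditional generator/discriminator pair, assuming dense models and made-up dimensions (latent_dim, context_dim, sample_dim are invented for illustration, not from the original post):

```python
# Hypothetical conditional GAN components; all shapes are illustrative.
from keras.layers import Input, Dense, Concatenate
from keras.models import Model

latent_dim, context_dim, sample_dim = 100, 10, 784  # invented sizes

# Generator: (context, latent) -> generated sample
g_context = Input(shape=(context_dim,), name="g_context")
g_latent = Input(shape=(latent_dim,), name="g_latent")
h = Dense(256, activation="relu")(Concatenate()([g_context, g_latent]))
g_out = Dense(sample_dim, activation="sigmoid")(h)
generator = Model([g_context, g_latent], g_out, name="generator")

# Discriminator: (context, sample) -> real/fake prediction
d_context = Input(shape=(context_dim,), name="d_context")
d_sample = Input(shape=(sample_dim,), name="d_sample")
h = Dense(256, activation="relu")(Concatenate()([d_context, d_sample]))
d_out = Dense(1, activation="sigmoid")(h)
discriminator = Model([d_context, d_sample], d_out, name="discriminator")
```

The only structural change from an unconditional GAN is that both models take the context as an additional input and concatenate it with their other input.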

Look at build_gan in adversarial_utils for a simple GAN.

yfake = Activation("linear", name="yfake")(discriminator(generator(generator.inputs)))
yreal = Activation("linear", name="yreal")(discriminator(discriminator.inputs))
model = Model(generator.inputs + discriminator.inputs, [yfake, yreal], name=name)

If context is Y and latent is Z and real data is X, you would do something like

yfake = Activation("linear", name="yfake")(discriminator([Y, generator([Y, Z])]))
yreal = Activation("linear", name="yreal")(discriminator([Y, X]))
model = Model([X, Y, Z], [yfake, yreal], name=name)

The Activations are just there to fix the output names; they're not strictly necessary.
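Putting the pieces together, a self-contained sketch of that wiring (the generator/discriminator bodies and all dimensions here are hypothetical stand-ins, not the library's own code):

```python
# Hypothetical end-to-end wiring of the conditional combined model.
from keras.layers import Input, Dense, Concatenate, Activation
from keras.models import Model

latent_dim, context_dim, sample_dim = 100, 10, 784  # invented sizes

def make_generator():
    y = Input(shape=(context_dim,))
    z = Input(shape=(latent_dim,))
    h = Dense(64, activation="relu")(Concatenate()([y, z]))
    return Model([y, z], Dense(sample_dim, activation="sigmoid")(h))

def make_discriminator():
    y = Input(shape=(context_dim,))
    x = Input(shape=(sample_dim,))
    h = Dense(64, activation="relu")(Concatenate()([y, x]))
    return Model([y, x], Dense(1, activation="sigmoid")(h))

generator = make_generator()
discriminator = make_discriminator()

X = Input(shape=(sample_dim,), name="X")   # real data
Y = Input(shape=(context_dim,), name="Y")  # context
Z = Input(shape=(latent_dim,), name="Z")   # latent sample

# Same wiring as above: fake path shares the context Y with the real path.
yfake = Activation("linear", name="yfake")(discriminator([Y, generator([Y, Z])]))
yreal = Activation("linear", name="yreal")(discriminator([Y, X]))
model = Model([X, Y, Z], [yfake, yreal], name="cgan")
```

Note that the combined model ends up with three inputs in the order [X, Y, Z], which is the order you will need to feed data in later.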

Cheers

bstriner commented 7 years ago

Basically, the number and order of inputs of your internal keras model determine the number and order of inputs of the keras-adversarial model that wraps it.
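A tiny plain-Keras illustration of that point (shapes are made up): the list passed to Model(...) fixes the order in which arrays must be supplied later.

```python
# The input order you declare is the order you must feed arrays in.
import numpy as np
from keras.layers import Input, Dense, Concatenate
from keras.models import Model

a = Input(shape=(2,), name="a")
b = Input(shape=(3,), name="b")
out = Dense(1)(Concatenate()([a, b]))
model = Model([a, b], out)  # inputs are [a, b], in that order

# predict expects arrays in the same order: [a-array, b-array]
pred = model.predict([np.zeros((4, 2)), np.zeros((4, 3))], verbose=0)
assert pred.shape == (4, 1)
```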