bstriner / keras-adversarial

Keras Generative Adversarial Networks
MIT License
867 stars, 231 forks

Add option to pass compile_kwargs for each player separately #39

Closed by d4nst 7 years ago

d4nst commented 7 years ago

This option would be useful for implementing models in which the discriminator and generator losses have different weights.

For example, the BEGAN loss function:

[image: BEGAN loss equations]

could be implemented by passing a symbolic variable kt as a loss weight for the discriminator and updating that variable in a callback:

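    from keras import backend as K
    from keras.optimizers import Adam
    from keras_adversarial import AdversarialOptimizerSimultaneous

    # model and custom_loss are defined elsewhere; kt is a backend variable so a callback can change it
    kt = K.variable(0.0)
    model.adversarial_compile(adversarial_optimizer=AdversarialOptimizerSimultaneous(),
                              player_optimizers=[Adam(1e-4), Adam(1e-4)],
                              player_compile_kwargs=[{}, {'loss_weights': {'yfake': kt}}],
                              loss=custom_loss)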
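
The callback side of this idea is not shown in the thread, so here is a minimal sketch of what it might look like. It assumes the BEGAN balance rule k_{t+1} = k_t + lambda_k * (gamma * L(x) - L(G(z))), with k_t kept in [0, 1]. The names `update_kt`, `KtUpdater`, `real_key` and `fake_key` are hypothetical, the log keys are placeholders rather than the names keras-adversarial actually reports, and the default `gamma` and `lambda_k` values are only illustrative. The getter and setter are injected so the sketch runs without Keras.

```python
def update_kt(kt, loss_real, loss_fake, gamma=0.5, lambda_k=0.001):
    """Apply one step of the BEGAN balance update and clip the result to [0, 1]."""
    kt_next = kt + lambda_k * (gamma * loss_real - loss_fake)
    return min(max(kt_next, 0.0), 1.0)


class KtUpdater:
    """Callback-shaped helper (hypothetical): call on_batch_end after each batch.

    get_kt and set_kt are plain callables. With Keras they could wrap
    K.get_value(kt) and K.set_value(kt, value); that wiring is an assumption.
    """

    def __init__(self, get_kt, set_kt, real_key, fake_key,
                 gamma=0.5, lambda_k=0.001):
        self.get_kt = get_kt
        self.set_kt = set_kt
        self.real_key = real_key  # log key of the loss on real samples (placeholder)
        self.fake_key = fake_key  # log key of the loss on generated samples (placeholder)
        self.gamma = gamma
        self.lambda_k = lambda_k

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        # Skip the update when either loss is missing from the logs.
        if self.real_key not in logs or self.fake_key not in logs:
            return
        self.set_kt(update_kt(self.get_kt(), logs[self.real_key],
                              logs[self.fake_key], self.gamma, self.lambda_k))


# Stand-in for the backend variable, so the example runs on its own.
state = {"kt": 0.0}
updater = KtUpdater(get_kt=lambda: state["kt"],
                    set_kt=lambda value: state.update(kt=value),
                    real_key="real_loss", fake_key="fake_loss")
updater.on_batch_end(0, {"real_loss": 1.0, "fake_loss": 0.2})
```

In a real training loop the same `on_batch_end` logic could be attached through `keras.callbacks.LambdaCallback(on_batch_end=updater.on_batch_end)`, with `get_kt` and `set_kt` backed by `K.get_value` and `K.set_value` on the `kt` variable passed as the loss weight.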

bstriner commented 7 years ago

Great idea, and if someone wants the old functionality they can just repeat the same kwargs for each player. It needs some tweaks to pass Travis (see review). I'm away this week, so feel free to make the changes yourself. Otherwise, I'll try to put this in next weekend.

d4nst commented 7 years ago

I agree with your review. I just pushed a commit with the changes.

bstriner commented 7 years ago

Tests look good. Thanks a lot!