osh / KerasGAN

A couple of simple GANs in Keras

make_trainable() does not freeze weights #10

Open embanner opened 7 years ago

embanner commented 7 years ago

You define a function make_trainable() which sets every layer's trainable attribute to either True or False, and you call it repeatedly during training. However, setting keras.layers.Layer.trainable has no effect unless you recompile the model afterwards. So I'm pretty sure your layers stay unfrozen for the entire training process, since you only compile once.

I'll take a stab at verifying this shortly.
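For reference, here is a rough sketch of the fix I have in mind: a version of make_trainable() that recompiles after flipping the flags. The loss and optimizer below are placeholders, not necessarily what this repo uses.

def make_trainable(model, trainable):
    """Flip the trainable flag on a model and all of its layers, then recompile.

    Without the compile() call at the end, the already-compiled training
    function keeps using the old trainable/non-trainable split.
    """
    model.trainable = trainable
    for layer in model.layers:
        layer.trainable = trainable
    # Placeholder compile arguments -- use whatever the model was
    # originally compiled with.
    model.compile(loss='binary_crossentropy', optimizer='adam')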

embanner commented 7 years ago

Confirmed that make_trainable(discriminator, False) does not actually freeze the weights.

>>> discriminator.predict(X)
array([[ 0.52295244,  0.47704756],
       [ 0.54938567,  0.45061436]], dtype=float32)
>>> make_trainable(discriminator, False)
>>> discriminator.train_on_batch(X, y)
>>> discriminator.predict(X)
array([[ 0.4992643 ,  0.50073564],
       [ 0.64071965,  0.35928035]], dtype=float32)
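A more direct way to check the same thing (my own sketch, reusing the discriminator, X and y from above) is to compare the raw weight arrays around one training step:

import numpy as np

# Snapshot the weights, "freeze", run one update, and compare.
weights_before = [w.copy() for w in discriminator.get_weights()]
make_trainable(discriminator, False)
discriminator.train_on_batch(X, y)
weights_after = discriminator.get_weights()

# Prints True only if no weight array changed, i.e. the freeze actually worked.
print(all(np.array_equal(b, a) for b, a in zip(weights_before, weights_after)))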
li-js commented 7 years ago

I think you are right; the re-compilation is what actually freezes the weights. Do you go further and re-compile the model inside the make_trainable() function? I do something similar, and the program ends up consuming incrementally more GPU memory on every iteration until an OOM error occurs. Have you had a similar experience? Any help is appreciated.

embanner commented 7 years ago

Yes, I call compile() inside make_trainable(), and it does indeed slow things down quite a bit. I find it interesting that even without freezing the weights, the generator still produces good-quality images.

vforvinay commented 7 years ago

One thing I noticed, but am not totally sure of, is that adding this line makes the discriminator part of the GAN untrainable before the GAN is compiled. This would mean that the discriminator model itself is trainable, but the discriminator inside the GAN is not, which is exactly what we want. See the sketch below.
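To spell that ordering out, here is a minimal sketch (layer sizes and losses are made up for illustration; the key point is which model is compiled before and after the flags are flipped):

from keras.models import Sequential
from keras.layers import Dense

# Discriminator is compiled while trainable, so calling
# discriminator.train_on_batch() updates its weights.
discriminator = Sequential([Dense(64, activation='relu', input_dim=100),
                            Dense(2, activation='softmax')])
discriminator.compile(loss='categorical_crossentropy', optimizer='adam')

# Generator is only ever trained through the stacked GAN, so it does not
# need its own compile.
generator = Sequential([Dense(100, activation='tanh', input_dim=32)])

# Freeze the discriminator *before* compiling the stacked model. The
# discriminator compiled above keeps its trainable weights, while the
# GAN compiled below sees the discriminator as frozen.
discriminator.trainable = False
for layer in discriminator.layers:
    layer.trainable = False

gan = Sequential([generator, discriminator])
gan.compile(loss='categorical_crossentropy', optimizer='adam')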

5agado commented 7 years ago

With Keras 2.0.4 I tried make_trainable() as defined here, and checking with summary() I can see the parameters switch from trainable to non-trainable without having to call compile() again. I suggest checking this on your setup first. Otherwise, I also hit OOM errors when recompiling every time.
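For anyone who wants to reproduce that check (a sketch, assuming discriminator is an already-built Keras 2.x model and make_trainable() only flips the flags, as in this repo):

discriminator.summary()                      # note the "Trainable params" count
make_trainable(discriminator, False)
discriminator.summary()                      # "Non-trainable params" should now cover all the weights
print(len(discriminator.trainable_weights))  # 0 once every layer is frozen

Note that summary() reflecting the new flags is a separate question from whether an already-compiled training function respects them, which is what the earlier comments are debating.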