mjdietzx / SimGAN

Implementation of Apple's Learning from Simulated and Unsupervised Images through Adversarial Training
MIT License

I cannot understand why discriminator_model.trainable=False is kept all the time? #6

Closed AlexHex7 closed 6 years ago

AlexHex7 commented 6 years ago

The code in sim-gan.py

    refiner_model.compile(optimizer=sgd, loss=self_regularization_loss)
    discriminator_model.compile(optimizer=sgd, loss=local_adversarial_loss)
    discriminator_model.trainable = False
    combined_model.compile(optimizer=sgd, loss=[self_regularization_loss, local_adversarial_loss])

I think that when pre-training the discriminator network it should be changed to True, and when running Algorithm 1 it should be toggled alternately.

Maybe I misunderstood it; could you explain it to me? Thanks!!

AlexHex7 commented 6 years ago

Sorry, I was wrong...

hellojialee commented 6 years ago

Hi @AlexHex7, could you please explain that? I guess that once you compile the discriminator with trainable = True, its compiled trainable state will not change even if you later set discriminator_model.trainable = False. The combined model consists of the trainable refiner model and the untrainable discriminator model, so we needn't toggle the trainable state alternately during training, which is different from many other examples such as: https://github.com/osh/KerasGAN/blob/master/MNIST_CNN_GAN_v2.ipynb
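
For illustration, here is a minimal sketch of the training loop this pattern allows (the batch helpers get_synthetic_batch/get_real_batch and the label shapes are placeholders, not code from this repository; in sim-gan.py the adversarial labels are per-patch to match local_adversarial_loss):

    import numpy as np

    batch_size = 32
    num_steps = 1000
    y_real = np.ones((batch_size, 1))      # placeholder label shape
    y_refined = np.zeros((batch_size, 1))  # placeholder label shape

    for step in range(num_steps):
        # Refiner update: the discriminator is frozen inside combined_model
        # because trainable was False when combined_model was compiled.
        synthetic = get_synthetic_batch(batch_size)
        combined_model.train_on_batch(synthetic, [synthetic, y_real])

        # Discriminator update: discriminator_model was compiled while its
        # trainable flag was still True, so these calls do update its weights.
        refined = refiner_model.predict_on_batch(synthetic)
        discriminator_model.train_on_batch(get_real_batch(batch_size), y_real)
        discriminator_model.train_on_batch(refined, y_refined)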

AlexHex7 commented 6 years ago

Hi, I'm so sorry for not replying to you in time. In fact, I have my Google mailbox forwarded to my school mailbox, but perhaps because of my country's policy or an important conference in my country, the mail could not be forwarded to my school mailbox.

We can find a sentence about the functions .eval and .train in the PyTorch documentation -- 'This has any effect only on modules such as Dropout or BatchNorm.' (http://pytorch.org/docs/master/nn.html?highlight=eval#torch.nn.Module.eval).

I think there are some differences between Keras and PyTorch. In Keras, the .fit function is for training and the .predict function is for testing, so I think the handling of the trainable state is contained in them. And we must be clear that the train/eval state in PyTorch only affects the behaviour of Dropout and BatchNorm layers; it has nothing to do with updating the weights. In PyTorch,

    for p in self.D.parameters():
        p.requires_grad = False

has the same meaning as

    for l in net.layers:
        l.trainable = val

in Keras.

In my code, I do training and testing alternately, so I need to invoke .train and .eval alternately. When initializing the optimisers, I pass the relevant parameters into them, so in fact I do not need to set requires_grad.

    self.opt_R = torch.optim.Adam(self.R.parameters(), lr=cfg.r_lr)
    self.opt_D = torch.optim.SGD(self.D.parameters(), lr=cfg.d_lr)
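
To make the point concrete, here is a minimal runnable sketch of that alternating scheme (toy networks and random tensors stand in for the real refiner R, discriminator D and image pipeline; none of this is code from the repository): each optimiser holds only its own network's parameters, so stepping one optimiser never modifies the other network, even though gradients flow through both.

    import torch
    import torch.nn as nn

    # Toy stand-ins for the refiner R and discriminator D.
    R = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
                      nn.Conv2d(8, 1, 3, padding=1))
    D = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
                      nn.Flatten(), nn.Linear(8 * 16 * 16, 1))

    opt_R = torch.optim.Adam(R.parameters(), lr=1e-4)  # holds only R's parameters
    opt_D = torch.optim.SGD(D.parameters(), lr=1e-3)   # holds only D's parameters

    adv_loss = nn.BCEWithLogitsLoss()
    reg_loss = nn.L1Loss()

    for step in range(100):
        synthetic = torch.rand(4, 1, 16, 16)
        real = torch.rand(4, 1, 16, 16)
        ones, zeros = torch.ones(4, 1), torch.zeros(4, 1)

        # Refiner update: gradients flow through D, but opt_R.step() only
        # modifies R's weights, so D is untouched even with requires_grad=True.
        refined = R(synthetic)
        loss_r = adv_loss(D(refined), ones) + reg_loss(refined, synthetic)
        opt_R.zero_grad()
        loss_r.backward()
        opt_R.step()

        # Discriminator update: detach the refined batch so no gradient flows
        # back into R; opt_D.step() only modifies D's weights.
        loss_d = adv_loss(D(real), ones) + adv_loss(D(refined.detach()), zeros)
        opt_D.zero_grad()
        loss_d.backward()
        opt_D.step()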

This is my understanding; if I have something wrong, please do tell me, I will appreciate it. And again, I must say that I'm so sorry for not replying to you in time.

Best regards, Hex

hellojialee commented 6 years ago

@AlexHex7 Hi, thank you for your reply. I debugged the process of training a GAN in Keras, and I found that we must specify the trainable states before we compile the models. There is a container that keeps track of all the trainable weights during training. So, as far as I can see, we must set discriminator.trainable = False before we compile the combined model; toggling the trainable state alternately during training afterwards may have no effect.
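
One way to verify this (a hypothetical check, not code from the repository; synthetic, real, y_real are assumed to be batches and labels of the right shapes for the compiled models): take one combined-model step and confirm the discriminator weights did not move, then take one discriminator_model step and confirm they did.

    import numpy as np

    before = [w.copy() for w in discriminator_model.get_weights()]

    # One refiner update through the combined model: the discriminator was
    # frozen inside combined_model at compile time, so its weights stay put.
    combined_model.train_on_batch(synthetic, [synthetic, y_real])
    after = discriminator_model.get_weights()
    assert all(np.array_equal(b, a) for b, a in zip(before, after))

    # One direct discriminator update: discriminator_model was compiled while
    # trainable was still True, so this call does change its weights.
    discriminator_model.train_on_batch(real, y_real)
    changed = discriminator_model.get_weights()
    assert any(not np.array_equal(b, c) for b, c in zip(before, changed))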

AlexHex7 commented 6 years ago

I get it! Thanks for sharing.
