bis-carbon closed this issue 5 years ago.
Thank you for your opinion. I think setting `trainable` to `False` is correct.
I have attached the pseudocode from the paper. `all_model` is used on line 10, where the discriminator is fixed and only the completion network is trained. The discriminator update happens on line 8, which uses `d_model` in my code.
I agree with setting `d_container.trainable = False`, but once the discriminator is made non-trainable, you would not be able to train it on the following batches. The discriminator has to keep being trained as long as t > T_C, and to do that I think we need to set `d_container.trainable = True` after the completion network has been trained. Correct me if I am wrong, and thank you for your quick response.
The algorithm in the paper is roughly:
```python
if n < tc:
    ...  # train the completion network only
elif n < tc + td:
    ...  # train the discriminator only
else:
    ...  # train both the completion network and the discriminator
```
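The three-phase schedule above can be sketched as a small, runnable function. `Phase`, `training_phase`, and `networks_to_update` are hypothetical names; `tc` and `td` are the iteration thresholds from the pseudocode. The sketch makes one point explicit: in the last phase the discriminator is still updated on every batch.

```python
from enum import Enum


class Phase(Enum):
    """Which networks are updated at a given training iteration."""
    COMPLETION_ONLY = 1     # n < tc
    DISCRIMINATOR_ONLY = 2  # tc <= n < tc + td
    JOINT = 3               # n >= tc + td


def training_phase(n: int, tc: int, td: int) -> Phase:
    """Map iteration n to a phase, following the if/elif/else schedule."""
    if n < tc:
        return Phase.COMPLETION_ONLY
    elif n < tc + td:
        return Phase.DISCRIMINATOR_ONLY
    else:
        return Phase.JOINT


def networks_to_update(phase: Phase) -> set:
    """Hypothetical names of the networks trained in each phase."""
    if phase is Phase.COMPLETION_ONLY:
        return {"completion"}
    if phase is Phase.DISCRIMINATOR_ONLY:
        return {"discriminator"}
    return {"completion", "discriminator"}


for n in range(7):
    print(n, training_phase(n, tc=3, td=2).name)
```

With `tc=3` and `td=2`, iterations 0 to 2 train only the completion network, 3 and 4 only the discriminator, and from 5 onward both.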
What I am suggesting is something like this:
```python
if n >= tc + td:
    # Freeze the discriminator inside the combined model
    d_container.trainable = False
    all_model = Model([org_img, mask, in_pts], [cmp_out, d_container([cmp_out, in_pts])])
    all_model.compile(loss=['mse', 'binary_crossentropy'], loss_weights=[1.0, alpha], optimizer=optimizer)
    g_loss = all_model.train_on_batch([inputs, masks, points], [inputs, valid])
    # train_on_batch returns [total, mse, bce]; the total already includes alpha * bce
    g_loss = g_loss[0]
    # The following code makes the discriminator trainable again for the next batch
    d_container.trainable = True
    all_model.compile(loss=['mse', 'binary_crossentropy'], loss_weights=[1.0, alpha], optimizer=optimizer)
```
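On the logged generator loss: as far as I know, for a Keras model with two outputs compiled with `loss_weights=[1.0, alpha]`, `train_on_batch` returns a list in the order of `model.metrics_names`, i.e. the weighted total first, then each output's own loss. The stdlib-only sketch below mimics that return value; `weighted_total` and `fake_train_on_batch` are hypothetical helpers, not Keras functions, and regularization losses are ignored.

```python
from typing import List


def weighted_total(losses: List[float], loss_weights: List[float]) -> float:
    """Sum of per-output losses scaled by loss_weights (regularization ignored)."""
    return sum(w * loss for w, loss in zip(loss_weights, losses))


def fake_train_on_batch(mse: float, bce: float, alpha: float) -> List[float]:
    """Hypothetical stand-in for what a two-output model's train_on_batch
    returns: [total_loss, output_1_loss, output_2_loss]."""
    per_output = [mse, bce]
    return [weighted_total(per_output, [1.0, alpha])] + per_output


g_loss = fake_train_on_batch(mse=0.5, bce=2.0, alpha=0.25)
print(g_loss)      # [1.0, 0.5, 2.0]
total = g_loss[0]  # already mse + alpha * bce; no recombination needed
```

So the first entry can be logged directly; adding further weighted terms on top of it would inflate the reported value.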
The following link may be helpful for your point: https://stackoverflow.com/questions/45154180/how-to-dynamically-freeze-weights-after-compiling-model-in-keras
The `trainable` flag is fixed in the model at compile time, so changing the flag after compiling does not affect the already-compiled model.
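To make the compile-time behaviour concrete, here is a toy, pure-Python model of it (not Keras's implementation; `ToyLayer`, `ToyModel`, and `disc` are hypothetical names). As I understand the reply, the discriminator is compiled on its own in `d_model` while it is still trainable, then frozen and wrapped in `all_model`. Each model keeps the flag it saw at compile time, so no toggling back to `True` is needed.

```python
class ToyLayer:
    """Stand-in for d_container: one scalar weight and a trainable flag."""

    def __init__(self, weight: float = 0.0) -> None:
        self.weight = weight
        self.trainable = True


class ToyModel:
    """Stand-in for a compiled model. compile() records which layers are
    trainable; train_on_batch() only updates the recorded layers."""

    def __init__(self, layers: list) -> None:
        self.layers = layers
        self._trainable_snapshot: list = []

    def compile(self) -> None:
        # The trainable flag is read here, once, at compile time.
        self._trainable_snapshot = [layer for layer in self.layers if layer.trainable]

    def train_on_batch(self, step: float = 1.0) -> None:
        for layer in self._trainable_snapshot:
            layer.weight += step


disc = ToyLayer()              # shared by both models, like d_container
d_model = ToyModel([disc])
d_model.compile()              # compiled while disc.trainable is True

disc.trainable = False
all_model = ToyModel([disc])
all_model.compile()            # compiled while disc.trainable is False

disc.trainable = True          # flipping the flag afterwards changes nothing

all_model.train_on_batch()     # disc stays frozen inside all_model
d_model.train_on_batch()       # but is still updated through d_model
print(disc.weight)             # 1.0
```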
Thank you, that clarifies my question.
Thank you for the great work. Don't you think `d_container.trainable` should be set to `True` after training `all_model`? Something like this: