YuvalBahat / Explorable-Super-Resolution

Apache License 2.0

question about discriminator success rate #11

Closed: usstdqq closed this issue 3 years ago

usstdqq commented 3 years ago

First, thanks for sharing the training/testing code!

I've been trying to train ExplorableSR on another dataset. I found that the generator_step flag is always False, even after training for a few hours. After some debugging, I think this is because the discriminator's success rate stays below the default 0.9 threshold.
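For readers unfamiliar with this kind of gating, here is a minimal sketch of how a success-rate check can hold back generator updates. It is a hypothetical illustration, not the repository's code: the function names, the definition of "success rate" as the fraction of real samples scored positive plus fake samples scored negative, and the use of `>=` for the comparison are all assumptions. Only the `generator_step` flag and the 0.9 default come from this thread.

```python
from typing import Sequence


def discriminator_success_rate(real_logits: Sequence[float], fake_logits: Sequence[float]) -> float:
    # Hypothetical definition: a real sample is classified correctly when its logit
    # is positive, a generated (fake) sample when its logit is negative.
    correct = sum(1 for x in real_logits if x > 0) + sum(1 for x in fake_logits if x < 0)
    total = len(real_logits) + len(fake_logits)
    if total == 0:
        raise ValueError("at least one sample is required")
    return correct / total


def should_step_generator(success_rate: float, threshold: float = 0.9) -> bool:
    # The generator is only updated once the discriminator is reliable enough.
    # Whether the real code uses >= or > here is an assumption.
    return success_rate >= threshold


rate = discriminator_success_rate([2.1, 0.4, -0.3], [-1.0, -0.2, 0.5])
generator_step = should_step_generator(rate)  # 4 of 6 samples correct, so False
```

If the discriminator never climbs above the threshold, a flag computed this way stays False indefinitely, which matches the symptom described above.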

How many steps are usually needed for the discriminator to reach the 0.9 success rate? Providing a pre-trained discriminator would help speed up the training process.

YuvalBahat commented 3 years ago

Thank you for your feedback. Stabilizing training of the Explorable SR GAN is a bit challenging, and we are currently still working to improve this. In the meantime, I suggest you try training without the discriminator verification mechanism, by simply commenting out this line.

Following your suggestion, I also added a link to the pre-trained discriminator corresponding to the provided pre-trained explorable SR model, so that you can initialize both generator and discriminator models to the provided pre-trained ones, and fine-tune them from there using your own dataset.

I hope this can help.

usstdqq commented 3 years ago

Thanks for providing the pre-trained discriminator!

When loading the provided pre-trained discriminator, I found a state_dict mismatch. The provided model has these keys:

(Pdb) loaded_state_dict.keys()
odict_keys(['features.0.weight', 'features.0.bias', 'features.2.weight', 'features.2.bias', 'features.3.weight', 'features.3.bias', 'features.3.running_mean', 'features.3.running_var', 'features.3.num_batches_tracked', 'features.5.weight', 'features.5.bias', 'features.6.weight', 'features.6.bias', 'features.6.running_mean', 'features.6.running_var', 'features.6.num_batches_tracked', 'features.8.weight', 'features.8.bias', 'features.9.weight', 'features.9.bias', 'features.9.running_mean', 'features.9.running_var', 'features.9.num_batches_tracked', 'features.11.weight', 'features.11.bias', 'features.12.weight', 'features.12.bias', 'features.12.running_mean', 'features.12.running_var', 'features.12.num_batches_tracked', 'features.14.weight', 'features.14.bias', 'features.15.weight', 'features.15.bias', 'features.15.running_mean', 'features.15.running_var', 'features.15.num_batches_tracked', 'features.17.weight', 'features.17.bias', 'features.18.weight', 'features.18.bias', 'features.18.running_mean', 'features.18.running_var', 'features.18.num_batches_tracked', 'features.20.weight', 'features.20.bias', 'features.21.weight', 'features.21.bias', 'features.21.running_mean', 'features.21.running_var', 'features.21.num_batches_tracked', 'features.23.weight', 'features.23.bias', 'features.24.weight', 'features.24.bias', 'features.24.running_mean', 'features.24.running_var', 'features.24.num_batches_tracked', 'features.26.weight', 'features.26.bias', 'features.27.weight', 'features.27.bias', 'features.27.running_mean', 'features.27.running_var', 'features.27.num_batches_tracked', 'classifier.0.weight', 'classifier.0.bias', 'classifier.2.weight', 'classifier.2.bias'])

while the defined netD has these keys:

(Pdb) netD.state_dict().keys()
odict_keys(['features.0.weight', 'features.0.bias', 'features.2.weight', 'features.2.bias', 'features.3.weight', 'features.3.bias', 'features.3.running_mean', 'features.3.running_var', 'features.3.num_batches_tracked', 'features.5.weight', 'features.5.bias', 'features.6.weight', 'features.6.bias', 'features.6.running_mean', 'features.6.running_var', 'features.6.num_batches_tracked', 'features.8.weight', 'features.8.bias', 'features.9.weight', 'features.9.bias', 'features.9.running_mean', 'features.9.running_var', 'features.9.num_batches_tracked', 'features.11.weight', 'features.11.bias', 'features.12.weight', 'features.12.bias', 'features.12.running_mean', 'features.12.running_var', 'features.12.num_batches_tracked', 'features.14.weight', 'features.14.bias', 'features.15.weight', 'features.15.bias', 'features.15.running_mean', 'features.15.running_var', 'features.15.num_batches_tracked', 'classifier.0.0.weight', 'classifier.0.0.bias', 'classifier.0.1.weight', 'classifier.0.1.bias', 'classifier.0.1.running_mean', 'classifier.0.1.running_var', 'classifier.0.1.num_batches_tracked', 'classifier.2.0.weight', 'classifier.2.0.bias', 'classifier.2.1.weight', 'classifier.2.1.bias', 'classifier.2.1.running_mean', 'classifier.2.1.running_var', 'classifier.2.1.num_batches_tracked'])

It seems that the provided model has more layers (features) and a different classification head. I am using the provided config for netD; the opt_net looks like {'which_model_D': 'discriminator_vgg_128', 'relativistic': 0, 'decomposed_input': 0, 'pre_clipping': 0, 'add_quantization_noise': 0, 'norm_type': 'batch', 'act_type': 'leakyrelu', 'mode': 'CNA', 'n_layers': 6, 'in_nc': 3, 'nf': 64}

Should I use a different config for netD, or is there a key mapping from the provided netD to the defined netD?
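Diffing the two key lists shows exactly where they diverge. The sketch below is a hypothetical pure-Python helper (`diff_state_dict_keys` is an illustrative name, not part of the repository). In PyTorch itself, `Module.load_state_dict(state_dict, strict=False)` reports the same information through the `missing_keys` and `unexpected_keys` fields of its return value.

```python
from typing import Iterable, List, Tuple


def diff_state_dict_keys(checkpoint_keys: Iterable[str], model_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key lists, sorted for readable output."""
    ckpt, model = set(checkpoint_keys), set(model_keys)
    missing = sorted(model - ckpt)      # expected by the model, absent from the checkpoint
    unexpected = sorted(ckpt - model)   # present in the checkpoint, unknown to the model
    return missing, unexpected


# A small excerpt of the two key lists above.
checkpoint = ["features.0.weight", "features.17.weight", "classifier.0.weight"]
model = ["features.0.weight", "classifier.0.0.weight"]
missing, unexpected = diff_state_dict_keys(checkpoint, model)
# missing    -> ['classifier.0.0.weight']
# unexpected -> ['classifier.0.weight', 'features.17.weight']
```

Run on the full lists, the unexpected keys would be `features.17` through `features.27` plus the flat `classifier.0`/`classifier.2` entries, and the missing keys would be the nested `classifier.0.*`/`classifier.2.*` entries.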

YuvalBahat commented 3 years ago

Thank you for letting me know about this problem. I'm currently debugging it, and will let you know once this problem is fixed.

YuvalBahat commented 3 years ago

I found the problem and pushed the fix. One line in the code needed to be commented out, and the number of discriminator layers in the configuration file had to be set back to 10. Please let me know if you run into additional problems.
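The layer count that a checkpoint expects can be read off its keys. The sketch below is a hypothetical helper, not repository code, and it assumes the `features` stack contains only conv layers, batch-norm layers, and parameter-free activations, as the key lists above suggest. It also assumes that `n_layers` corresponds to the number of conv layers in that stack, which is consistent with the thread but not confirmed by it.

```python
import re
from typing import Iterable

_FEATURE_KEY = re.compile(r"features\.(\d+)\.(\w+)")


def count_feature_convs(keys: Iterable[str]) -> int:
    """Count conv layers in a 'features.<index>.<param>' stack.

    Assumption: an indexed module that has a weight but no running_mean is a conv
    layer (batch-norm layers also store running_mean; activations have no parameters).
    """
    with_weight, batch_norm = set(), set()
    for key in keys:
        match = _FEATURE_KEY.fullmatch(key)
        if match is None:
            continue  # classifier.* keys are ignored
        index, param = int(match.group(1)), match.group(2)
        if param == "weight":
            with_weight.add(index)
        elif param == "running_mean":
            batch_norm.add(index)
    return len(with_weight - batch_norm)
```

Applied to the two `odict_keys` lists in this thread, this counts 10 conv layers in the provided checkpoint (indices 0, 2, 5, 8, 11, 14, 17, 20, 23, 26) and 6 in the netD built with `'n_layers': 6`. That matches the fix of setting the layer count back to 10.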

usstdqq commented 3 years ago

Hello, do you mean commenting out this line of code? https://github.com/YuvalBahat/Explorable-Super-Resolution/commit/b17c856351ac28d373d08520f3cac8b5a4d1d88c#diff-ff9a5453e4fe531d6f7671c949e9dd0eb54b92761a63a32cf210253501cf279aR514 I just noticed you have multiple working branches :)

usstdqq commented 3 years ago

Hmm, the netD loading seems to work now. Thanks!