shaoanlu / faceswap-GAN

A denoising autoencoder + adversarial losses and attention mechanisms for face swapping.

faceswap_gan_model.py - training_updates #174

Open lutzfinger opened 3 years ago

lutzfinger commented 3 years ago

I am running into issues at the line:

model.build_train_functions(loss_weights=loss_weights, **loss_config)

I am getting the following error:


---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)
<ipython-input-37-685e0523626c> in <module>()
     14 
     15 model.build_pl_model(vggface_model=vggface, before_activ=loss_config["PL_before_activ"])
---> 16 model.build_train_functions(loss_weights=loss_weights, **loss_config)

/content/faceswap-GAN/networks/faceswap_gan_model.py in build_train_functions(self, loss_weights, **loss_config)
    250         # Define training functions
    251         # Adam(...).get_updates(...)
--> 252         training_updates = Adam(lr=self.lrD*loss_config['lr_factor'], beta_1=0.5).get_updates(weightsDA,[],loss_DA)
    253         self.netDA_train = K.function([self.distorted_A, self.real_A],[loss_DA], training_updates)
    254         training_updates = Adam(lr=self.lrG*loss_config['lr_factor'], beta_1=0.5).get_updates(weightsGA,[], loss_GA)

TypeError: get_updates() takes 3 positional arguments but 4 were given

Thanks for any help.

mirfan899 commented 3 years ago

Getting the same error.

fungtion commented 3 years ago

Changing Adam(lr=self.lrG*loss_config['lr_factor'], beta_1=0.5).get_updates(weightsGA, [], loss_GA) to Adam(lr=self.lrG*loss_config['lr_factor'], beta_1=0.5).get_updates(loss_GA, weightsGA) may help solve it.
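The error says get_updates() takes two arguments besides self, while the code passes three in the older Keras order (params, constraints, loss). Newer Keras versions expect get_updates(loss, params), which is the order the suggestion switches to. The traceback fails at line 252 on the DA call, so every get_updates call in build_train_functions needs the swap, not only the GA one. Below is a minimal sketch of a version-tolerant wrapper. `get_updates_compat` is a hypothetical helper (not part of Keras or faceswap-GAN), and it assumes those two signatures are the only ones you will meet; the stand-in optimizer classes only mimic them for demonstration.

```python
import inspect


def get_updates_compat(optimizer, loss, params):
    """Call optimizer.get_updates() with the argument order its signature expects.

    Hypothetical helper. Assumes the newer signature is get_updates(loss, params)
    and the older one is get_updates(params, constraints, loss).
    """
    # Signature of the bound method: 'self' is already excluded.
    names = list(inspect.signature(optimizer.get_updates).parameters)
    if names and names[0] == "loss":
        # Newer order: loss first, then the list of trainable weights.
        return optimizer.get_updates(loss, params)
    # Older order: weights, an empty constraints list, then the loss.
    return optimizer.get_updates(params, [], loss)


# Stand-in optimizers that only mimic the two signatures (not real Keras classes).
class NewStyleOptimizer:
    def get_updates(self, loss, params):
        return ("new", loss, list(params))


class OldStyleOptimizer:
    def get_updates(self, params, constraints, loss):
        return ("old", loss, list(params))


if __name__ == "__main__":
    print(get_updates_compat(NewStyleOptimizer(), "loss_DA", ["w1", "w2"]))
    print(get_updates_compat(OldStyleOptimizer(), "loss_DA", ["w1", "w2"]))
```

Inside build_train_functions, line 252 would then become training_updates = get_updates_compat(Adam(lr=self.lrD*loss_config['lr_factor'], beta_1=0.5), loss_DA, weightsDA), with the same pattern applied to the GA line and any other get_updates call in that method. If you only ever target the newer Keras, simply swapping every call to get_updates(loss, weights) is enough.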

varunp2k commented 3 years ago

@fungtion That did not help. Any other suggestions?

ParikhKadam commented 3 years ago

Please mention the TensorFlow version you are using and cross-check whether it matches the one listed in the requirements.

varunp2k commented 3 years ago

@ParikhKadam I am using `tf==1.15.5` and `keras==2.1.5`.

ParikhKadam commented 3 years ago

@varunp2k Can you try with these versions? tensorflow==1.8.0, keras==2.1.5
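To cross-check an environment such as Colab against pinned versions, a small script can read the installed package versions without importing TensorFlow itself. This is a minimal sketch: `check_versions` and `PINNED` are hypothetical names, the pins are the ones suggested in this thread, and it assumes the packages are installed under the distribution names `tensorflow` and `keras` (a `tensorflow-gpu` install would need its own entry). It uses `importlib.metadata`, which requires Python 3.8 or newer.

```python
from importlib import metadata

# Versions suggested in this thread; adjust to match the repo's requirements file.
PINNED = {"tensorflow": "1.8.0", "keras": "2.1.5"}


def check_versions(pins):
    """Return {package: (installed_or_None, pinned, matches)} for each pin."""
    report = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed in this environment
        report[name] = (installed, wanted, installed == wanted)
    return report


if __name__ == "__main__":
    for name, (installed, wanted, ok) in check_versions(PINNED).items():
        status = "OK" if ok else "MISMATCH"
        print(f"{name}: installed={installed} pinned={wanted} -> {status}")
```

Reading the distribution metadata rather than `tf.__version__` avoids loading TensorFlow just to check the version, and it reports a missing package as `None` instead of raising.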

varunp2k commented 3 years ago

@ParikhKadam Colab has moved to Python 3.7, and TensorFlow 1.8 isn't supported on Python 3.7.