vasily789 / adaptive-weighted-gans


No Issues: An example of using this loss in a GAN? #2

Open sumorday opened 2 years ago

sumorday commented 2 years ago

Thank you so much!:)

sumorday commented 2 years ago

https://colab.research.google.com/drive/1AsZztd0Af0UMzBXXkI9QKQZhAUoK01bk?usp=sharing

Here is my code. It seems I cannot call the aw-loss successfully.

vasily789 commented 2 years ago

Dear sumorday,

Thank you for your interest in our method. I took a quick look at your code and here are some questions/recommendations.

Are you able to get any results using the standard method, i.e. without our aw-method? I am just curious, since you are using the BCEWithLogitsLoss loss function as in the original implementation.

I see that you have made some changes to the aw-method. In particular, you are calculating the rs and fs scores from losses instead of from the raw validities that come from the Discriminator, but you are still taking the sigmoid of your BCEWithLogitsLoss losses. Because those losses are never negative, this keeps rs/fs always at or above 0.5. I would recommend either using our implementation with mean(sigmoid(validity)), or adjusting the decision tree correspondingly.
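To make the point about the sigmoid concrete, here is a minimal sketch in plain Python (no PyTorch), assuming the standard definition of binary cross-entropy on logits. The helper names (bce_with_logits, mean_sigmoid) and the sample logits are made up for illustration and are not taken from the repository.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def bce_with_logits(logits, target):
    """Mean binary cross-entropy on raw logits (numerically stable form)."""
    total = 0.0
    for x in logits:
        # max(x, 0) - x*t + log(1 + exp(-|x|)) is always >= 0
        total += max(x, 0.0) - x * target + math.log1p(math.exp(-abs(x)))
    return total / len(logits)

def mean_sigmoid(logits):
    return sum(sigmoid(x) for x in logits) / len(logits)

# Hypothetical raw discriminator outputs (validities) for a small batch.
real_validity = [-2.0, -0.5, 0.3]
fake_validity = [-3.0, -1.0, 0.1]

# Scores computed from losses: a loss is non-negative, so sigmoid(loss) >= 0.5.
rs_from_loss = sigmoid(bce_with_logits(real_validity, 1.0))
fs_from_loss = sigmoid(bce_with_logits(fake_validity, 0.0))

# Scores computed from validities, i.e. mean(sigmoid(validity)): anywhere in (0, 1).
rs = mean_sigmoid(real_validity)
fs = mean_sigmoid(fake_validity)

print("from losses:    ", rs_from_loss, fs_from_loss)
print("from validities:", rs, fs)
```

With these sample logits the validity-based scores fall below 0.5, while the loss-based scores cannot, whatever the discriminator outputs.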

The aw-method requires several inputs, and I believe that is why you had issues calling it.

Best, Vasily

sumorday commented 2 years ago

Dear Vasily,

Hi! Thank you for your prompt reply! The BCEWithLogitsLoss loss function works without any issues. This is a DCGAN model from Kaggle (https://www.kaggle.com/vatsalmavani/deep-convolutional-gan-in-pytorch). Or maybe I should use the hinge loss, according to the aw method, right?

I changed this part:

def aw_loss(self, Dloss_real, Dloss_fake, Dis_opt, Dis_Net, real_validity, fake_validity):
    rs = torch.mean(torch.sigmoid(real_validity))
    fs = torch.mean(torch.sigmoid(fake_validity))

Because I am not sure what real_validity/fake_validity are, I guess they mean the discriminator's loss on real/fake samples... I have no idea how to change this and call it correctly.

I also asked the question on Stack Overflow because I am not good at programming. Still no idea... https://stackoverflow.com/questions/71334763/how-to-class-another-loss-function-in-gan-discriminator/71338620#71338620

What should I do?

Best, Edward(sumorday)

HanAccount commented 2 years ago

Hi sumorday, I think real_validity may be equal to D(real_image). That is:

real_validity = D(real_image)
fake_validity = D(fake_image)

I've already used this method to run the code you gave above.

sumorday commented 2 years ago

Dear Han:

Thank you so much! Nice to hear that. I tried to add:

real_validity = D(real_image)
fake_validity = D(fake_image)

and:

def aw_loss(self, D_real_loss, D_fake_loss, D_opt, D, real_validity, fake_validity):

I guess Dis_Net is equal to D, which was defined earlier in the code: D = Discriminator().to(device) or D = D.apply(weights_init)

Then I tried to use your code:

D_loss = aw_method.aw_loss(D_real_loss, D_fake_loss, D_opt, real_validity=real_validity, fake_validity=fake_validity)

TypeError: aw_loss() missing 2 required positional arguments: 'D_opt' and 'D'

...I have no idea about this. https://colab.research.google.com/drive/17YJuOFKMGujJNNQ4KLU-_WhRIIhl_cqt#scrollTo=eaVc1nhLAObB

I really hope the author can share a toy GAN example, for instance using the MNIST or CIFAR dataset... I guess this is because I lack programming experience.

HanAccount commented 2 years ago

I took a look at your code and found two mistakes.

The call should be:

aw = aw_method()
D_loss = aw.aw_loss(D_real_loss, D_fake_loss, D_opt, real_validity=real_validity, fake_validity=fake_validity)

or:

D_loss = aw_method().aw_loss(D_real_loss, D_fake_loss, D_opt, real_validity=real_validity, fake_validity=fake_validity)

Also, the aw_loss() call is missing the argument D. It should end up being:

D_loss = aw_method().aw_loss(D_real_loss, D_fake_loss, D_opt, D, real_validity=real_validity, fake_validity=fake_validity)
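The two TypeErrors reported in this thread can be reproduced without any GAN code. The sketch below uses a stand-in class that only copies the argument list quoted earlier in the thread; its body is a placeholder and is not the repository's implementation. The numeric values and the strings "D_opt" and "D" are dummies standing in for the real losses, optimizer, and discriminator.

```python
class aw_method:
    """Stand-in with the argument list from the thread; NOT the real aw-method."""

    def aw_loss(self, D_real_loss, D_fake_loss, D_opt, D, real_validity, fake_validity):
        # Placeholder body: the real method computes adaptive weights here.
        return D_real_loss + D_fake_loss

def error_message(call):
    """Run call() and return the TypeError text, or None if it succeeds."""
    try:
        call()
    except TypeError as exc:
        return str(exc)
    return None

# 1) Called on the class, not an instance: the first value is consumed as `self`,
#    so every positional argument shifts by one and D_opt and D end up unfilled.
msg_on_class = error_message(
    lambda: aw_method.aw_loss(0.7, 0.6, "D_opt", real_validity=0.4, fake_validity=0.3)
)

# 2) Called on an instance, but the discriminator D is not passed.
msg_missing_d = error_message(
    lambda: aw_method().aw_loss(0.7, 0.6, "D_opt", real_validity=0.4, fake_validity=0.3)
)

# 3) Instance plus all six arguments: the call goes through.
D_loss = aw_method().aw_loss(0.7, 0.6, "D_opt", "D", real_validity=0.4, fake_validity=0.3)

print(msg_on_class)
print(msg_missing_d)
print(D_loss)
```

In a real training loop, the same D and D_opt objects that are created before the loop would be passed in, rather than constructing a fresh Discriminator() inside it.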

sumorday commented 2 years ago

Hi! Is D = Discriminator() right? When I added this inside the epoch loop, there was a TypeError: aw_loss() missing 1 required positional argument: 'D' https://colab.research.google.com/drive/17YJuOFKMGujJNNQ4KLU-_WhRIIhl_cqt#scrollTo=eaVc1nhLAObB

HanAccount commented 2 years ago

Yes, that's right! D is the Discriminator network.

qingyuany commented 11 months ago

"yes, all right! D is Discriminator network"

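Hello. Thank you for your excellent reply; it helped me a lot. I have also been learning about applications of GANs recently. Could you provide an example of using aw_loss in a conditional GAN? I look forward to your reply.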