nupurkmr9 / vision-aided-gan

Ensembling Off-the-shelf Models for GAN Training (CVPR 2022 Oral)
https://www.cs.cmu.edu/~vision-aided-gan/
MIT License

Regarding R1 Regularization #6

Closed: israrbacha closed this issue 2 years ago

israrbacha commented 2 years ago

Thanks for sharing this fantastic work.

  1. I want to use the vision-aided loss for few-shot adaptation. Should I use only cvD, or combine it with the original discriminator (net_D)? Wouldn't combining both risk overfitting, since there are more parameters than training images?
  2. How do I apply R1 regularization to the vision-aided discriminator (cvD) in the StyleGAN2 setting? In your code, R1 regularization is applied only to the original net_D.

nupurkmr9 commented 2 years ago

Hi,

  1. When training G and net_D from scratch on few-shot images, we found the best performance comes from combining cvD with the original discriminator (net_D), with cvD enabled only after a few warmup iterations, e.g. 200k images. We have shown experiments with 100-500 training images, and this setup works without overfitting. When finetuning a generator pretrained on a large-scale dataset such as FFHQ, I have observed that using only cvD also works, but the best performance may still come from using both net_D and cvD.

  2. Regarding R1 regularization on cvD: I found it had no significant effect on the final performance of G, so I did not include it in the final method. However, it may help when there are very few training images (~10-50). I am not sure whether your use case falls into this category. It should be possible by returning cv_feat from the cvD module and taking the gradients of logits_cv w.r.t. it in loss.py. I will try to update the code with an option to add R1 regularization for cvD as well.
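The warmup schedule from point 1 can be sketched as a simple gate on the adversarial loss: net_D contributes from the start, and the cvD term is added only once enough images have been seen. This is a minimal sketch, assuming the cvD term is simply added (optionally weighted) after warmup; `combined_adv_loss`, `warmup_images`, and `cvd_weight` are hypothetical names, not options from the repository, and the 200k default is the example figure given above.

```python
def combined_adv_loss(loss_net_d: float, loss_cvd: float, images_seen: int,
                      warmup_images: int = 200_000, cvd_weight: float = 1.0) -> float:
    """Hypothetical helper: always use net_D, add cvD only after the warmup."""
    total = loss_net_d
    # Assumption: cvD's adversarial term is switched on once the image budget is reached.
    if images_seen >= warmup_images:
        total += cvd_weight * loss_cvd
    return total


# During warmup only net_D counts; afterwards both discriminators contribute.
print(combined_adv_loss(1.0, 0.5, images_seen=100_000))  # 1.0
print(combined_adv_loss(1.0, 0.5, images_seen=200_000))  # 1.5
```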

Hope this answers your question. Let me know if you have any doubts. Thanks.

israrbacha commented 2 years ago

Thanks for the detailed response; point 1 is clear now. For point 2, yes, I am fine-tuning the large-scale G on ~10-40 samples. The cvD logits are a list [tensor1, tensor2, tensor3] for the multi-level output. I took the last tensor from the list to compute the gradient for R1, but the loss turned into 'nan' after some iterations. How should I use this list of logits to compute the gradient penalty: sum it or average it? The final output of net_D has shape (batch, 1), which differs from cvD.

nupurkmr9 commented 2 years ago

Hi, I tried R1 regularization with cvD using only the CLIP-based discriminator. I made the following changes in stylegan2/training/loss.py:

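# line 139
detach = True

# line 194
# Sum each level of the multi-level cvD logits to a scalar; autograd.grad
# accumulates the gradients of all these outputs w.r.t. the real images.
r1_grads = torch.autograd.grad(outputs=[each.sum() for each in logits_cv[0]], inputs=[real_img_tmp],
                               create_graph=True, only_inputs=True, allow_unused=True)[0]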

I haven't run into 'nan' issues early in training, but I will post here if the results differ significantly or if I hit any issues later in training.
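For experimenting outside the StyleGAN2 codebase, the two-line change above can be reproduced in a self-contained sketch. `ToyMultiLevelD` is a hypothetical stand-in for cvD that returns a list of three logit tensors with different shapes, like the `[tensor1, tensor2, tensor3]` output described in this thread; `r1_penalty_multilevel` and `gamma` are illustrative names, not part of the repository. Each level is reduced with `.sum()` because `torch.autograd.grad` needs scalar outputs when no `grad_outputs` is passed, and gradients from several outputs are accumulated, which equals differentiating the sum over all levels. As in the snippet, the gradient is taken w.r.t. the real images; the `cv_feat` variant mentioned earlier would pass the backbone features as `inputs` instead. The `(gamma / 2) * ||grad||^2` form mirrors the usual R1 penalty, and the gamma value is an assumption.

```python
import torch
import torch.nn as nn


class ToyMultiLevelD(nn.Module):
    """Hypothetical stand-in for cvD: returns one logit tensor per feature level."""

    def __init__(self, channels: int = 3):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, 8, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(8, 16, 3, stride=2, padding=1)
        self.head1 = nn.Conv2d(8, 1, 1)   # spatial logits at level 1
        self.head2 = nn.Conv2d(16, 1, 1)  # spatial logits at level 2
        self.head3 = nn.Linear(16, 1)     # global (batch, 1) logit

    def forward(self, x):
        f1 = torch.relu(self.conv1(x))
        f2 = torch.relu(self.conv2(f1))
        return [self.head1(f1), self.head2(f2), self.head3(f2.mean(dim=(2, 3)))]


def r1_penalty_multilevel(disc: nn.Module, real_img: torch.Tensor, gamma: float = 10.0) -> torch.Tensor:
    """R1 penalty on real images for a discriminator with a list of logit outputs."""
    real_img = real_img.detach().requires_grad_(True)
    logits = disc(real_img)
    # Reduce every level to a scalar; gradients of all levels are summed.
    grads = torch.autograd.grad(
        outputs=[each.sum() for each in logits],
        inputs=[real_img],
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )[0]
    # Squared gradient norm per sample, then averaged over the batch.
    per_sample = grads.square().sum(dim=[1, 2, 3])
    return (gamma / 2) * per_sample.mean()


torch.manual_seed(0)
disc = ToyMultiLevelD()
penalty = r1_penalty_multilevel(disc, torch.randn(4, 3, 16, 16))
penalty.backward()  # the penalty updates the discriminator's weights
print(penalty.item())
```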

Thanks.

israrbacha commented 2 years ago

Thanks a lot! Looking forward to the results.

nupurkmr9 commented 2 years ago

Hi,

I was able to train the model without any 'nan' issues.