nagadomi / nunif

Misc; latest version of waifu2x; 2D video to SBS 3D video
MIT License

Any suggestion to modify the arch based on the gan training result? #161

Open 3zhang opened 3 weeks ago

3zhang commented 3 weeks ago

I'm training a photo swin_unet_2x model using GAN. I use a cosine LR scheduler with an initial LR of 1e-5. After some tries I found that the discriminator loss fluctuated around 0.8 (the threshold for generator training to begin), so I increased the discriminator LR to 5e-5 (see the sketch at the end of this comment). And this is the result.

[Three screenshots, 2024-06-27: training loss curves]

After ~160 epochs the generator loss starts to increase, traded off against the decrease in discriminator loss, which is not good. My guess is that the generator model may be underfitting. So could you give me some suggestions for modifying the arch to make the model more complex? Or should I try a different arch?
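
For reference, here is a minimal sketch of the optimizer setup described above (the tiny stand-in modules and T_max value are assumptions, not the actual nunif training script):

```python
import torch
import torch.nn as nn

# stand-ins for the real models; swin_unet_2x and the discriminator are much larger
generator = nn.Conv2d(3, 3, 3, padding=1)
discriminator = nn.Conv2d(3, 1, 3, padding=1)

g_opt = torch.optim.Adam(generator.parameters(), lr=1e-5)      # init lr = 1e-5
d_opt = torch.optim.Adam(discriminator.parameters(), lr=5e-5)  # raised from 1e-5

# cosine schedule; T_max=200 epochs is an assumed value
g_sched = torch.optim.lr_scheduler.CosineAnnealingLR(g_opt, T_max=200)
d_sched = torch.optim.lr_scheduler.CosineAnnealingLR(d_opt, T_max=200)

# per the description above, generator GAN updates begin once the
# discriminator loss falls below the 0.8 threshold
```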

nagadomi commented 3 weeks ago

I plan to try again in this field in the next month or so.

In my application, the generator (SR model) has VRAM usage and inference time limitations, so the model cannot be made more complex. So, for my part, I think a smaller discriminator or some other technique will be necessary. (For experiments there are larger models such as waifu2x.swin_unet_4xl, waifu2x.winc_unet_2xl, and waifu2x.winc_unet_4xl, but they need training from scratch.)

The other problem I see is that disc_weight keeps decreasing; I plan to add an option to use a fixed weight instead of the adaptive weight.

Also, because of these problems, I have only trained up to epoch=40-80 (with --num-samples 25000) for the models in my current release.

3zhang commented 2 weeks ago

Could you briefly explain what the l3v1 discriminator is?

nagadomi commented 2 weeks ago

It is defined in waifu2x/models/discriminator.py. It is a PatchGAN-type discriminator with two branches: a shallow one (v1) and a deep one (l3).
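
Roughly, the structure looks like this sketch (a hedged approximation, not the actual code in waifu2x/models/discriminator.py; layer widths and depths are guesses):

```python
import torch
import torch.nn as nn

def conv_block(cin, cout, stride):
    return nn.Sequential(
        nn.Conv2d(cin, cout, 3, stride=stride, padding=1),
        nn.LeakyReLU(0.2, inplace=True))

class TwoBranchPatchDiscriminator(nn.Module):
    # PatchGAN-style: outputs per-patch logit maps instead of a single scalar
    def __init__(self, in_ch=3):
        super().__init__()
        self.stem = conv_block(in_ch, 64, stride=2)
        # shallow branch ("v1"): small receptive field, local texture
        self.shallow = nn.Sequential(
            conv_block(64, 64, stride=1),
            nn.Conv2d(64, 1, 3, padding=1))
        # deep branch ("l3"): more downsampling, larger receptive field
        self.deep = nn.Sequential(
            conv_block(64, 128, stride=2),
            conv_block(128, 256, stride=2),
            nn.Conv2d(256, 1, 3, padding=1))

    def forward(self, x):
        h = self.stem(x)
        return self.shallow(h), self.deep(h)  # two logit maps at different scales
```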

3zhang commented 2 weeks ago

About the adaptive weight: when the adaptive weight is small, does it mean that the gradient of the generator loss is much greater than the gradient of the reconstruction loss? So its purpose is to limit the GAN loss so that it does not change the model too fast and break the PSNR model, right?

nagadomi commented 2 weeks ago

Ideally, yes.

weight = l2norm(last_layer_gradient(recon_loss)) / l2norm(last_layer_gradient(gen_loss))
loss = recon_loss / weight + gen_loss

However, in practice, when the discriminator gets stronger, the weight decreases and PSNR increases. I think this means that 1/weight becomes much larger than ideal.

I referenced taming-transformers for the GAN loss (L1+LPIPS+GAN), so I adopted the adaptive weight.
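
As a sketch, the taming-transformers-style computation looks something like this, written to match the formula above (last_layer is the generator's final weight tensor; the eps and clamp values are assumptions):

```python
import torch

def adaptive_weight(recon_loss, gen_loss, last_layer, eps=1e-4):
    # gradient of each loss w.r.t. the generator's last layer weights
    recon_grad = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
    gen_grad = torch.autograd.grad(gen_loss, last_layer, retain_graph=True)[0]
    weight = recon_grad.norm() / (gen_grad.norm() + eps)
    return weight.clamp(1e-4, 1e4).detach()

# total_loss = recon_loss / weight + gen_loss   (per the formula above)
```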

3zhang commented 2 weeks ago

I see you use hinge loss for the discriminator. Have you tried other losses, like the Wasserstein loss?
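
For reference, the hinge formulation being referred to is roughly this minimal sketch (function names are illustrative, not nunif's API):

```python
import torch.nn.functional as F

def d_hinge_loss(real_logits, fake_logits):
    # push real logits above +1 and fake logits below -1
    return (F.relu(1.0 - real_logits).mean()
            + F.relu(1.0 + fake_logits).mean())

def g_hinge_loss(fake_logits):
    # the generator just maximizes the discriminator's score on fakes
    return -fake_logits.mean()

# a Wasserstein critic would instead use fake_logits.mean() - real_logits.mean(),
# plus a Lipschitz constraint (weight clipping or a gradient penalty)
```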

nagadomi commented 2 weeks ago

No, I do not have much GAN experience. My first try was this (DCGAN with binary cross-entropy): https://github.com/nagadomi/nunif/tree/master/playground/gan, and the next was waifu2x. (For the most part, I intended to follow taming-transformers; I first checked that it works on a 4x SR task using the CelebA dataset.)

3zhang commented 3 days ago

Does waifu2x support a self-supervised discriminator?

nagadomi commented 3 days ago

I have tried an autoencoder-type self-supervised discriminator (see FastGAN), but I have not committed the model code.
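
For readers, a rough sketch of the FastGAN-style idea (an assumption about the shape of the approach, not the uncommitted nunif code): the discriminator carries a small decoder that reconstructs real inputs from its features, giving it a self-supervised training signal.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfSupervisedDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True))
        self.head = nn.Conv2d(128, 1, 3, padding=1)  # real/fake patch logits
        self.decoder = nn.Sequential(                # tiny reconstruction decoder
            nn.Upsample(scale_factor=4, mode="nearest"),
            nn.Conv2d(128, 3, 3, padding=1))

    def forward(self, x):
        feat = self.encoder(x)
        return self.head(feat), self.decoder(feat)

# discriminator update on real images gets an extra reconstruction term:
#   logits, recon = disc(real)
#   d_loss = d_hinge_term(logits) + F.mse_loss(recon, real)
```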