madebyollin / taesd

Tiny AutoEncoder for Stable Diffusion

Finetuning taesd but getting dim and less saturated images #22

Open stardusts-hj opened 2 weeks ago

stardusts-hj commented 2 weeks ago

Hi madebyollin!

Thank you so much for your awesome work and kind reply. Sorry to bother you again. While trying to finetune the taesd decoder, I found that the images are coming out a little dim and less saturated. I suspect the reason may be the different datasets or data-processing strategies I've used. An example is shown below.

  1. For datasets, I use the same laion ae dataset, where I also find the images are less saturated, since lots of images have a white background and a single object in the center. The images are resized to 512x512. I adopted the color augmentation you suggested, but the output is still less saturated than taesd. It would be very kind if you could give some suggestions for dataset choice or data augmentation; would it help if I add more colorful datasets like [danbooru2021]?

  2. Besides, my training strategy is to train the taesd decoder with an lpips loss and a GAN loss. Which one do you think matters more for enhancing the output quality? Will the lpips loss affect the saturation of the images? If so, how about using a GAN loss only?

  3. It seems that taesd version 1.2 was trained starting from the weights of the previous version. Could you please share some details about the finetuning? (Did you initialize the discriminator from the previous version? Why did you remove the lpips loss in version 1.2?)

I sincerely appreciate your wonderful work and enthusiasm. It would be really great if you could give me some suggestions for finetuning taesd. Many thanks to you.

(left: taesd, right: my finetuned version)

stardusts-hj commented 2 weeks ago

Another question: since I only finetune the taesd decoder and leave the encoder frozen, should I also train the taesd encoder on my datasets?

madebyollin commented 2 weeks ago
  1. "will it help if I add more colorful datasets" - depends on where the problem is. a. If your decoder generates dim/desaturated images even during training, then it's a loss problem (try adding some low-res MSE loss, like low_res_mse_loss = F.mse_loss(F.avg_pool2d(decoded, 8), F.avg_pool2d(real_images, 8))). b. If your decoder generates saturated images during training, but dim/desaturated images during eval on the training dataset, then it's a train/eval mismatch problem (or maybe a training resolution problem, iirc you need to train on at least 256x256 images to get correct colors). c. If your decoder generates saturated images during training, and saturated during eval on training examples, but desaturated images during eval on test data, then it is a dataset problem (try adding more saturated images to the dataset).
  2. "lpips loss and gan loss. Do you think which one matters more " - The losses do different things. For training the decoder, it's possible to use only GAN loss as long as you condition the GAN discriminator on the latents, but in practice you probably want to include LPIPS and low-res MSE to make training converge faster.
  3. TAESD 1.2 used a newly-trained conditional discriminator (like the Seraena one) so I didn't need to use LPIPS. LPIPS was wasting a bunch of memory / time so I decided to remove it.
  4. "should I also train the encoder of taesd with my datasets" probably not. You might be able to get slightly better reconstructions by finetuning the encoder, but it won't fix sharpness or saturation issues.
stardusts-hj commented 2 weeks ago

Thank you so much for your helpful suggestions and quick reply! I observe the degradation even during training, so I will check my training losses again (1.a). As for the low-res MSE loss (2.), it seems like the taesd decoder is treated as a conditional GAN, which is a generative model, so a pixel-wise reconstruction loss may be unnecessary. I will focus on the GAN part and try to improve it. I really appreciate your suggestions about how to locate the problem. Thank you!
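
For the GAN part, the latent-conditioned discriminator I have in mind looks roughly like this (just a sketch of the "condition the discriminator on the latents" idea; the architecture and channel counts are my own guesses, not TAESD's actual discriminator):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LatentConditionedDiscriminator(nn.Module):
    """Sketch of a patch discriminator conditioned on the latents.

    The latents are upsampled to image resolution and concatenated with the
    image channels, so the discriminator judges "is this a plausible decode
    of *these* latents" rather than just "is this a plausible image".
    """

    def __init__(self, image_channels=3, latent_channels=4, base_channels=64):
        super().__init__()
        in_channels = image_channels + latent_channels
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(base_channels, base_channels * 2, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(base_channels * 2, base_channels * 4, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(base_channels * 4, 1, 4, stride=1, padding=1),  # per-patch logits
        )

    def forward(self, images, latents):
        # SD-style latents are 8x smaller than the image, so upsample to match.
        latents_up = F.interpolate(latents, size=images.shape[-2:], mode="nearest")
        return self.net(torch.cat([images, latents_up], dim=1))
```

Does this look like a reasonable starting point?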