madebyollin / taesd

Tiny AutoEncoder for Stable Diffusion
MIT License

How do you train the TAESD's latent space shared with SD VAE? #16

Closed ivylilili closed 3 months ago

ivylilili commented 3 months ago

Hi Ollin~ I have read a lot of your discussion on Reddit, GitHub gists, and open-source projects about the VAE in SD/SDXL. Thanks for sharing! I know you are an expert on VAEs and I'd like to ask you some questions.

Regarding TAESD, I learned from your answers in other issues that "TAESD receives the image input in [0, 1] and has scale_factor=1 rather than SD's 0.18215". But your model has a totally different (much simpler) structure than SD's VAE. So I wonder: how do you train TAESD? From scratch, or by reusing some of the SD weights?

Besides, I noticed you mentioned that TAESD's latent space is compatible with SD's, needing only the scaling_factor to adjust. This really surprised me. What magic did you use to make these two latent spaces compatible? I'm new to distillation, and I would appreciate it if you could explain how you achieved this.

The official SD VAE is trained with a small KL weight (1e-6), and its latent space indeed has a large variance (std = 1/0.18215); SDXL's is even larger (std = 1/0.13025). Your TAESD has a small variance (std = 1/1). I have also trained the SD VAE, and in my experiments the std of the latent space tends to decrease (the KL loss also decreases). Does your TAESD show the same tendency?
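For concreteness, the scaling_factor-to-std relationship mentioned above can be written out (the 0.18215 / 0.13025 values come from the SD and SDXL VAE configs; the resulting std figures are approximate):

```python
# scaling_factor is chosen so that scaled latents have roughly unit std:
#   scaling_factor ≈ 1 / std(raw_encoder_latents)
sd_latent_std = 1 / 0.18215    # raw SD VAE latent std, ≈ 5.49
sdxl_latent_std = 1 / 0.13025  # raw SDXL VAE latent std, ≈ 7.68
taesd_latent_std = 1 / 1.0     # TAESD latents are already ~unit scale

print(round(sd_latent_std, 2), round(sdxl_latent_std, 2), taesd_latent_std)
```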

madebyollin commented 3 months ago

how do you train the TAESD ? From Scratch or Taking reference to some weights in the SD?

TAESD weights are initialized from scratch, but TAESD is supervised using outputs from the SD VAE.

what magic did you take to make these two latent spaces compatible

TAESD is just trained to directly mimic the SD encoder / decoder, without using any KL loss or whatever. In pseudocode:

def taesd_training_loss(real_images, taesd, sdvae):
    taesd_latents = taesd.encoder(real_images)
    sdvae_latents = apply_scaling_factor(sdvae.encoder(real_images))
    taesd_images = taesd.decoder(sdvae_latents)

    # in the ideal case latent_loss and image_loss would be F.mse_loss,
    # but in practice they're some mix of adversarial and perceptual loss terms
    return latent_loss(taesd_latents, sdvae_latents) + image_loss(taesd_images, real_images)
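For reference, the `apply_scaling_factor` step in the pseudocode above could look like this (a minimal sketch, assuming SD's published scaling_factor of 0.18215; the helper names are illustrative, not from the actual training code):

```python
SD_SCALING_FACTOR = 0.18215  # from the SD VAE config; TAESD itself uses 1.0

def apply_scaling_factor(latents, scaling_factor=SD_SCALING_FACTOR):
    # SD convention: multiply raw VAE encoder outputs by scaling_factor
    # so the latents seen by the diffusion model have roughly unit variance.
    return latents * scaling_factor

def undo_scaling_factor(latents, scaling_factor=SD_SCALING_FACTOR):
    # divide before feeding latents back into the SD VAE decoder
    return latents / scaling_factor
```

Because TAESD is trained against these scaled latents, its encoder outputs (and decoder inputs) live directly in the unit-scale space, which is why no extra 0.18215 factor is needed when using TAESD.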

https://github.com/madebyollin/taesd/issues/11#issuecomment-1914990359 has more detail about the approach used

according to the experiment, the std of latent space tends to decrease (the kl loss is also decreasing). Does your TAESD have the same tendency?

TAESD is entirely deterministic and doesn't use any KL loss. So, no :)
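To illustrate the difference (a toy scalar sketch, not the real models): a VAE-style encoder predicts a distribution, samples from it, and pays a KL penalty pulling that distribution toward N(0, 1), whereas a TAESD-style encoder is a plain deterministic function with no sampling and no KL term.

```python
import math
import random

def vae_style_encode(mean, logvar):
    # VAE-style: sample z ~ N(mean, exp(logvar)) and compute the
    # per-dimension KL divergence from N(0, 1).
    std = math.exp(0.5 * logvar)
    z = mean + std * random.gauss(0.0, 1.0)
    kl = 0.5 * (mean ** 2 + math.exp(logvar) - 1.0 - logvar)
    return z, kl

def taesd_style_encode(x, encoder):
    # TAESD-style: deterministic, no sampling, no KL loss at all.
    return encoder(x)
```

With no KL term, there is no pressure on TAESD's latent std during training other than matching the (already scaled) SD latents, so the shrinking-std tendency described above doesn't apply.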

ivylilili commented 3 months ago

Thanks for your response. It really helps!