roymiles / VkD

[CVPR 2024] VkD: Improving Knowledge Distillation using Orthogonal Projections

Implementation details for Data limited image generation #3

Closed: hjinnkim closed this issue 3 months ago

hjinnkim commented 3 months ago

First of all, thank you for your impressive work!

I have a few questions about data-limited image generation.

  1. In your paper, you compared your method with KD-DLGAN.

    I guess you distilled a pre-trained BigGAN into a randomly initialized BigGAN in the data-limited regime.

    As far as I know, KD-DLGAN aims at improving GAN training with the aid of CLIP.

    It is a little confusing to me because the teacher model appears to have changed. Do I understand correctly?

  2. Could you explain how you distill BigGAN? I cannot find which feature is distilled from the teacher model.

    If you could share example code, that would be great.

Thank you in advance!

roymiles commented 3 months ago

Hi HyeonJin!

We just modified the code provided by KD-DLGAN by removing some of the unnecessary losses and adding in our orthogonal feature loss. We used the same CLIP teacher as KD-DLGAN.

For example, all of the distillation losses are given in this file here.

  1. We removed the out-of-distribution text distillation loss, i.e. set D_loss_outdistri_div = torch.tensor(0.0)
  2. We also removed the other text losses, i.e. the added img_lan_similarity_loss(...) components on L71 and L73 (a sketch of both removals follows this list)
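
A minimal sketch of those two removals, assuming the variable names and surrounding discriminator-loss code from the KD-DLGAN repository (the exact context may differ):

# 1) zero out the out-of-distribution text distillation loss
D_loss_outdistri_div = torch.tensor(0.0)

# 2) simply no longer add the image-language similarity terms
#    (assuming they were accumulated into the regularisation losses)
# D_loss_real_reg += img_lan_similarity_loss(...)  # removed
# D_loss_fake_reg += img_lan_similarity_loss(...)  # removed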

We then introduce our additional feature distillation loss like this:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Frozen CLIP teacher: extract features for both real and fake images.
# The images are upsampled to the CLIP input resolution first
# (e.g. 32x32 inputs * 7 = 224x224).
with torch.no_grad():
    real_images = nn.Upsample(scale_factor=7, mode='nearest')(real_images)
    latent_real = CLIP.encode_image(real_images)

    fake_images = nn.Upsample(scale_factor=7, mode='nearest')(fake_images)
    latent_fake = CLIP.encode_image(fake_images)

...

# orthogonal projection into the teacher feature space,
# shared for both the real and fake branches
D_real_out_reg = G.projector(D_real_out_reg)
D_fake_out_reg = G.projector(D_fake_out_reg)

# standardise the teacher features across the feature dimension,
# e.g. for layer norm (no learned affine parameters)
latent_real_norm = (latent_real - latent_real.mean(1, keepdim=True)) / latent_real.std(1, keepdim=True)
latent_fake_norm = (latent_fake - latent_fake.mean(1, keepdim=True)) / latent_fake.std(1, keepdim=True)

# smooth L1 between the normalised teacher features and the projected
# student features, accumulated into the existing regularisation losses
D_loss_real_reg += F.smooth_l1_loss(latent_real_norm, D_real_out_reg) * weighting
D_loss_fake_reg += F.smooth_l1_loss(latent_fake_norm, D_fake_out_reg) * weighting
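
For reference, here is a minimal, hypothetical sketch of what an orthogonal projector like G.projector could look like, using PyTorch's built-in orthogonal parametrisation (the repository's exact implementation and feature dimensions may differ):

import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal

# hypothetical dimensions: student (discriminator) feature dim -> CLIP feature dim
student_dim, teacher_dim = 512, 512

# bias-free linear layer whose weight is kept (semi-)orthogonal throughout training
projector = orthogonal(nn.Linear(student_dim, teacher_dim, bias=False))

Any implementation that keeps the projection weight (semi-)orthogonal should behave similarly, since that constraint is the core idea of the paper.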
hjinnkim commented 3 months ago

Thank you for the kind response! This helps me a lot!