roymiles / VkD

[CVPR 2024] VkD : Improving Knowledge Distillation using Orthogonal Projections

Some questions for vkd #7

Closed TongkunGuan closed 2 months ago

TongkunGuan commented 2 months ago

Thank you for your excellent work on knowledge distillation.

It is a great idea to find a P such that P^T P = I (i.e. P^(-1) == P^(T) in the square case). In the implementation, I noticed that P is defined as:

student.projector = torch.nn.utils.parametrizations.orthogonal(
    nn.Linear(student.num_features, teacher.num_features, bias=False)
)
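
(A minimal sketch of what this parametrization enforces, not code from the repo; the dimensions below are assumptions. With out_features >= in_features the parametrized weight W has orthonormal columns, so W^T W ≈ I:)

import torch
import torch.nn as nn

# Minimal check (assumed dims, not the repo's): the orthogonal parametrization
# keeps the columns of the (Ct x Cs) weight orthonormal, so W^T W ≈ I_Cs
# whenever teacher.num_features >= student.num_features.
student_dim, teacher_dim = 384, 768   # hypothetical feature sizes
proj = torch.nn.utils.parametrizations.orthogonal(
    nn.Linear(student_dim, teacher_dim, bias=False)
)
W = proj.weight                        # shape (teacher_dim, student_dim)
print(torch.allclose(W.T @ W, torch.eye(student_dim), atol=1e-4))  # True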

However, this only builds the representation loss on the last layer of features:

z_s_pool = z_s.mean(1)  # average-pool the student features to one vector per image
z_s_pool = student.module.projector(z_s_pool)  # orthogonal projection to the teacher dimension
z_t_conv_pool = z_t_conv.view(b, c, h * w).mean(-1)  # spatially average-pool the teacher features -> (b, c)
z_t_conv_norm = F.layer_norm(z_t_conv_pool, (z_t_conv_pool.shape[1],))  # normalise the pooled teacher features
repr_distill_loss = args.alpha * F.smooth_l1_loss(z_s_pool, z_t_conv_norm)  # smooth L1 between projected student and normalised teacher

I want to know:

1) Can repr_distill_loss be established between more layers, e.g. (c2, c3, c4, c5) in ResNet-50 matched roughly to layers (2, 5, 7, 9) in ViT-B?
2) Before establishing repr_distill_loss, both the student and teacher features are globally pooled. Why not project the features with shape (B, C, H, W) directly and then establish the loss?

roymiles commented 2 months ago

Hi, thanks! Using intermediate layers may improve performance, but it is difficult to select which layers to pair for the losses. From my experience, forcing the earlier layers to fit too many of the teacher's intermediate features can degrade downstream task performance. The features just before the output offer a good compromise.
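
(A purely hypothetical sketch of question 1, not something from the paper or repo: one orthogonal projector per selected student stage, with assumed stage names, channel widths, and a multi_stage_repr_loss helper that does not exist in the codebase.)

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical multi-stage variant (not the released code): one orthogonal
# projector per selected student stage, each paired with one teacher layer.
stage_dims = {"c3": 512, "c4": 1024, "c5": 2048}  # assumed ResNet-50 stage widths
teacher_dim = 768                                  # assumed ViT-B width
projectors = nn.ModuleDict({
    name: torch.nn.utils.parametrizations.orthogonal(
        nn.Linear(dim, teacher_dim, bias=False))
    for name, dim in stage_dims.items()
})

def multi_stage_repr_loss(student_feats, teacher_feats, alpha=1.0):
    # student_feats[name]: (B, Cs, H, W) conv features; teacher_feats[name]: (B, N, Ct) tokens
    loss = 0.0
    for name, z_s in student_feats.items():
        z_s_pool = projectors[name](z_s.flatten(2).mean(-1))   # (B, Ct)
        z_t_pool = teacher_feats[name].mean(1)                  # (B, Ct)
        z_t_norm = F.layer_norm(z_t_pool, (z_t_pool.shape[1],))
        loss = loss + F.smooth_l1_loss(z_s_pool, z_t_norm)
    return alpha * loss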

With regards to the shapes, this is because the teacher is a transformer and so has features of shape (B, N, Ct), while the student has shape (B, Cs, H, W). Pooling out the (H, W) and (N) dims seemed logical and worked quite well. However, there may well be a better way.
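
(For question 2, a rough sketch of a token-wise alternative that skips the pooling. This is an assumption-laden illustration rather than the repo's method: it presumes the student's H*W spatial grid matches the teacher's N patch tokens, with no CLS token; otherwise the teacher tokens would need resizing first.)

import torch.nn.functional as F

def tokenwise_repr_loss(z_s, z_t, projector, alpha=1.0):
    # z_s: (B, Cs, H, W) student conv features
    # z_t: (B, N, Ct) teacher transformer tokens, assuming N == H * W
    z_s_tok = z_s.flatten(2).transpose(1, 2)          # (B, H*W, Cs)
    z_s_proj = projector(z_s_tok)                     # (B, H*W, Ct), same orthogonal projector
    z_t_norm = F.layer_norm(z_t, (z_t.shape[-1],))    # per-token layer norm
    return alpha * F.smooth_l1_loss(z_s_proj, z_t_norm)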

TongkunGuan commented 2 months ago


Thanks for your answer!