Hi, thanks! Using intermediate layers may improve the performance, but it is difficult to select which layers to use for the losses. In my experience, forcing the earlier layers to fit too many of the teacher's intermediates can degrade downstream task performance. The features just before the output offer a good compromise.
With regard to the shapes: the teacher is a transformer, so its features have shape (B, N, Ct), while the student's have shape (B, Cs, H, W). Pooling out the (H, W) and (N) dims seemed logical and worked quite well, though there may well be a better way.
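For concreteness, here is a minimal sketch of the pooling-then-projection described above, with assumed ViT-B/ResNet-50 shapes and an illustrative loss function (not necessarily the repo's exact code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed shapes for illustration: a ViT-B teacher and a ResNet-50 student.
B, N, Ct = 8, 197, 768           # teacher tokens: (B, N, Ct)
Cs, H, W = 2048, 7, 7            # student feature map: (B, Cs, H, W)

teacher_feats = torch.randn(B, N, Ct)
student_feats = torch.randn(B, Cs, H, W)

# Pool out the token dim (N) of the teacher and the spatial dims (H, W) of the student.
teacher_vec = teacher_feats.mean(dim=1)        # (B, Ct)
student_vec = student_feats.mean(dim=(2, 3))   # (B, Cs)

# Orthogonally parametrized projector from student width to teacher width.
projector = torch.nn.utils.parametrizations.orthogonal(
    nn.Linear(Cs, Ct, bias=False)
)

# Illustrative representation loss on the pooled, projected features.
repr_distill_loss = F.smooth_l1_loss(projector(student_vec), teacher_vec)
```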
Thanks for your answer!
Thank you for your excellent work on knowledge distillation.
It is a great idea to find a P such that P^(-1) == P^(T). In the implementation, I noticed that P is defined as:
student.projector = torch.nn.utils.parametrizations.orthogonal(nn.Linear(student.num_features, teacher.num_features, bias=False))
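As a quick, hypothetical sanity check with placeholder widths, the parametrized weight does satisfy P @ P.T == I (semi-orthogonality, since the weight here is not square):

```python
import torch
import torch.nn as nn

# Placeholder widths; the real student/teacher feature dims may differ.
proj = torch.nn.utils.parametrizations.orthogonal(
    nn.Linear(2048, 768, bias=False)   # student.num_features -> teacher.num_features
)
P = proj.weight                        # parametrized weight, shape (768, 2048)
print(torch.allclose(P @ P.T, torch.eye(768), atol=1e-5))  # True: rows are orthonormal
```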
However, this only builds the representation loss on the last layer of features:
I want to know:
1) Can repr_distill_loss be established between more layers, e.g. (c2, c3, c4, c5) in ResNet-50 matched to layers (2, 5, 7, 9) of ViT-B?
2) Before computing repr_distill_loss, both the student and teacher features are globally pooled. Why not project the features directly in their (B, C, H, W) shape and then compute the loss?
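For question 2, a hypothetical per-location variant (projecting the (B, Cs, H, W) map with a 1x1 conv and reshaping the teacher tokens into a grid) could look like the sketch below; the CLS-token handling and the 7x7 grid size are assumptions, not taken from the repo:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, Cs, H, W = 8, 2048, 7, 7
Ct = 768

student_feats = torch.randn(B, Cs, H, W)
teacher_tokens = torch.randn(B, 1 + H * W, Ct)   # assumes a CLS token + 7x7 patch grid

# Project the student map per location instead of pooling first.
proj_1x1 = nn.Conv2d(Cs, Ct, kernel_size=1, bias=False)
student_proj = proj_1x1(student_feats)           # (B, Ct, H, W)

# Drop the CLS token and reshape the patch tokens into a spatial grid to align with the student.
patch_tokens = teacher_tokens[:, 1:, :]                            # (B, H*W, Ct)
teacher_grid = patch_tokens.transpose(1, 2).reshape(B, Ct, H, W)   # (B, Ct, H, W)

loss_per_location = F.mse_loss(student_proj, teacher_grid)
```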