Hi @Ftiasch1,
Sorry for the late response.
When ResNet-50 is the teacher and DeiT-T is the student, we project the student feature into the teacher's feature space so the two can be aligned. You can follow the pseudocode below to build the projector for each stage and then apply the FitNet loss between the projected student feature and the teacher feature.
NOTE: the code below may not work with the existing code without some debugging. It is just a simple implementation for applying FitNet to heterogeneous models, without careful design.
    for stage in self.stages:
        # get information about feature shape at each stage
        _, size_t = teacher.stage_info(stage)
        in_chans, H, W = size_t
        _, size_s = student.stage_info(stage)
        patch_num, embed_dim = size_s

        # to remove cls token in student feature
        token_num = getattr(student, 'num_tokens', 0)
        feature_filter = TokenFilter(token_num, remove_mode=True)

        # transform the student feature to have the same shape as the teacher feature
        patch_grid = int((patch_num - token_num) ** .5)
        if H >= patch_grid:
            # the first several stages
            patch_size = H // patch_grid
            assert patch_size * patch_grid == H
            projector = nn.Sequential(
                feature_filter,
                nn.Linear(embed_dim, patch_size ** 2 * in_chans),
                Unpatchify(patch_size)
            )
        else:
            # the last stage
            assert patch_grid % H == 0
            scale = patch_grid // H
            projector = nn.Sequential(
                feature_filter,
                MyPatchMerging(H ** 2),
                LambdaModule(
                    lambda x: torch.einsum('nhwc->nchw', x.view(x.size(0), H, H, x.size(-1))).contiguous()),
                nn.Conv2d(embed_dim * scale ** 2, in_chans, 1, 1, 0)
            )
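The snippet above relies on `TokenFilter`, `Unpatchify`, `MyPatchMerging` and `LambdaModule` without defining them. The sketch below shows one way they might be written so the projector can be tried in isolation. These classes are hypothetical stand-ins written for illustration, and the repository's own versions may differ. The feature shapes are also assumptions: a DeiT-T-like student with 1 cls token plus 14x14 patches and `embed_dim=192`, and ResNet-50-like teacher stages for a 224x224 input. The FitNet loss is taken here to be a plain MSE between the projected student feature and the teacher feature.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class TokenFilter(nn.Module):  # hypothetical stand-in
    """Drop the first `num_tokens` tokens (e.g. cls) when remove_mode=True."""
    def __init__(self, num_tokens, remove_mode=True):
        super().__init__()
        self.num_tokens, self.remove_mode = num_tokens, remove_mode

    def forward(self, x):  # x: (N, L, C)
        if self.remove_mode:
            return x[:, self.num_tokens:, :]
        return x[:, :self.num_tokens, :]


class Unpatchify(nn.Module):  # hypothetical stand-in
    """(N, g*g, p*p*C) tokens -> (N, C, g*p, g*p) feature map."""
    def __init__(self, patch_size):
        super().__init__()
        self.p = patch_size

    def forward(self, x):
        n, l, d = x.shape
        g, p = math.isqrt(l), self.p
        c = d // (p * p)
        x = x.reshape(n, g, g, p, p, c).permute(0, 5, 1, 3, 2, 4)
        return x.reshape(n, c, g * p, g * p)


class MyPatchMerging(nn.Module):  # hypothetical stand-in
    """(N, g*g, C) -> (N, out_patch_num, s*s*C), grouping s x s neighbours."""
    def __init__(self, out_patch_num):
        super().__init__()
        self.out_patch_num = out_patch_num

    def forward(self, x):
        n, l, c = x.shape
        g, h = math.isqrt(l), math.isqrt(self.out_patch_num)
        s = g // h
        x = x.reshape(n, h, s, h, s, c).permute(0, 1, 3, 2, 4, 5)
        return x.reshape(n, h * h, s * s * c)


class LambdaModule(nn.Module):  # hypothetical stand-in
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)


def build_projector(in_chans, H, patch_num, embed_dim, token_num):
    """Mirror of the per-stage logic in the snippet above."""
    feature_filter = TokenFilter(token_num, remove_mode=True)
    patch_grid = math.isqrt(patch_num - token_num)
    if H >= patch_grid:
        patch_size = H // patch_grid
        assert patch_size * patch_grid == H
        return nn.Sequential(
            feature_filter,
            nn.Linear(embed_dim, patch_size ** 2 * in_chans),
            Unpatchify(patch_size),
        )
    assert patch_grid % H == 0
    scale = patch_grid // H
    return nn.Sequential(
        feature_filter,
        MyPatchMerging(H ** 2),
        LambdaModule(lambda x: x.reshape(x.size(0), H, H, x.size(-1))
                     .permute(0, 3, 1, 2).contiguous()),
        nn.Conv2d(embed_dim * scale ** 2, in_chans, 1, 1, 0),
    )


torch.manual_seed(0)
feat_s = torch.randn(2, 197, 192)  # assumed student tokens: cls + 14*14 patches
shapes, losses = {}, {}
# assumed teacher stage shapes (channels, H) for a 224x224 input
for in_chans, H in [(256, 56), (512, 28), (1024, 14), (2048, 7)]:
    feat_t = torch.randn(2, in_chans, H, H)
    projector = build_projector(in_chans, H, patch_num=197, embed_dim=192, token_num=1)
    out = projector(feat_s)
    shapes[H] = tuple(out.shape)
    # FitNet-style hint loss: MSE against the (detached) teacher feature
    losses[H] = F.mse_loss(out, feat_t.detach()).item()
```

In a real training loop the projectors would be created once (for example in an `nn.ModuleDict` keyed by stage) so that their parameters are registered with the optimizer, and the per-stage losses would be summed and added to the task loss.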
Thanks for your reply.
Thanks to the authors for their contributions. I am having trouble reproducing the OFA combined with FitNet approach, e.g. with ResNet-50 as the teacher model and DeiT-T as the student model. I don't know how to set up the training process and wonder if the authors could give some specific guidance. Thank you very much for your help.