Hao840 / OFAKD

PyTorch code and checkpoints release for OFA-KD: https://arxiv.org/abs/2310.19444

How do I combine FitNet when using OFA? #5

Closed: Ftiasch1 closed this issue 7 months ago

Ftiasch1 commented 7 months ago

Thanks to the authors for their contributions. I am having some trouble reproducing OFA combined with FitNet, for example in the case of ResNet-50 as the teacher model and DeiT-T as the student model. I don't know how to start the training process and wonder whether the authors could give some specific guidance. Thank you very much for your help.

Hao840 commented 7 months ago

Hi @Ftiasch1,

Sorry for the late response.

In the case where ResNet-50 is the teacher and DeiT-T is the student, we transform the student feature into the feature space of the teacher for alignment. You can follow the pseudo code below to achieve the alignment and then adopt the FitNet loss.

NOTE: the code below may not work with the existing code without some debugging. It is just a simple implementation for applying FitNet to heterogeneous models, without careful design.

import torch
import torch.nn as nn

# build one projector per distillation stage and keep them on the distiller
# module so their parameters are trained together with the student
self.projectors = nn.ModuleList()
for stage in self.stages:
    # get information about the feature shape at each stage
    _, size_t = teacher.stage_info(stage)
    in_chans, H, W = size_t  # teacher feature map: (in_chans, H, W)

    _, size_s = student.stage_info(stage)
    patch_num, embed_dim = size_s  # student tokens: (patch_num, embed_dim)

    # remove the cls token (and any other extra tokens) from the student feature
    token_num = getattr(student, 'num_tokens', 0)
    feature_filter = TokenFilter(token_num, remove_mode=True)

    # transform the student feature to have the same shape as the teacher feature
    patch_grid = int((patch_num - token_num) ** .5)
    if H >= patch_grid:
        # the first several stages: the teacher resolution is higher, so expand
        # each student token into a (patch_size x patch_size) spatial block
        patch_size = H // patch_grid
        assert patch_size * patch_grid == H
        projector = nn.Sequential(
            feature_filter,
            nn.Linear(embed_dim, patch_size ** 2 * in_chans),
            Unpatchify(patch_size)
        )
    else:
        # the last stage: the teacher resolution is lower, so merge neighboring
        # student tokens, reshape to NCHW, and match channels with a 1x1 conv
        assert patch_grid % H == 0
        scale = patch_grid // H
        projector = nn.Sequential(
            feature_filter,
            MyPatchMerging(H ** 2),
            # bind H as a default argument so each stage's lambda keeps its own value
            LambdaModule(
                lambda x, H=H: torch.einsum('nhwc->nchw', x.view(x.size(0), H, H, x.size(-1))).contiguous()),
            nn.Conv2d(embed_dim * scale ** 2, in_chans, 1, 1, 0)
        )
    self.projectors.append(projector)
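
After the projection, the student feature has the same shape as the teacher feature, so the FitNet hint loss is just an MSE between the two. Below is a minimal sketch of how the loss could be computed per stage; the function and variable names (fitnet_loss, feats_s, feats_t) are only illustrative and are not part of this repository.

import torch.nn.functional as F

def fitnet_loss(projectors, feats_s, feats_t):
    # feats_s[i]: student tokens at stage i, shape (N, patch_num, embed_dim)
    # feats_t[i]: teacher feature map at stage i, shape (N, in_chans, H, W)
    loss = 0.
    for projector, f_s, f_t in zip(projectors, feats_s, feats_t):
        f_s_aligned = projector(f_s)  # projected to (N, in_chans, H, W)
        loss = loss + F.mse_loss(f_s_aligned, f_t)
    return loss

The total training objective would then add this hint loss, with some weight, to the usual task and distillation losses.
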
Ftiasch1 commented 7 months ago

Thanks for your reply.