lenscloth / RKD

Official PyTorch implementation of Relational Knowledge Distillation, CVPR 2019

In the loss design, is the `torch.no_grad` essential? #16

Closed leoluopy closed 4 years ago

leoluopy commented 4 years ago

Hi, glad to see you. I am reading your loss design now and found the code below:

```python
class RKdAngle(nn.Module):
    def forward(self, student, teacher):
        # student, teacher: N x C
        with torch.no_grad():
            # pairwise differences: N x N x C
            td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
            norm_td = F.normalize(td, p=2, dim=2)
            t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)

        sd = (student.unsqueeze(0) - student.unsqueeze(1))
        norm_sd = F.normalize(sd, p=2, dim=2)
        s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)

        loss = F.smooth_l1_loss(s_angle, t_angle, reduction='elementwise_mean')
        return loss
```

In both the RKD angle loss and the RKD distance loss, the teacher-related code is wrapped in `torch.no_grad`. Is that essential? Can it be removed?

lenscloth commented 4 years ago

@leoluopy Hello,

"torch.no_grad" is not necessary. Since the teacher model is already forwarded in the context of "torch.no_grad"

```python
with torch.no_grad():
    t_b1, t_b2, t_b3, t_b4, t_pool, t_e = teacher(teacher_normalize(images), True)
```
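To make the reasoning concrete, here is a minimal standalone sketch (not code from this repo) showing why the inner `torch.no_grad` is redundant: a tensor produced inside a `torch.no_grad` context has `requires_grad=False`, so later operations on it never build an autograd graph, with or without a second `no_grad` wrapper.

```python
import torch

# A stand-in for model inputs that would normally require gradients.
x = torch.randn(4, 8, requires_grad=True)

# Stand-in for the teacher forward pass, done under no_grad
# (as in the training loop quoted above).
with torch.no_grad():
    teacher_emb = x * 2.0

# The teacher embeddings are detached from the graph:
print(teacher_emb.requires_grad)  # False

# Ops on them outside any no_grad context still track no gradients,
# so wrapping the pairwise-difference computation in no_grad again
# changes nothing about gradient flow:
td = teacher_emb.unsqueeze(0) - teacher_emb.unsqueeze(1)
print(td.requires_grad)  # False

# By contrast, the same op on a graph-connected tensor does track grads:
y = x * 2.0
print(y.requires_grad)  # True
```

The inner `no_grad` can still be a harmless safeguard (and saves memory if the loss were ever called with a teacher forwarded outside `no_grad`), but it is not required for correctness here.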