irfanICMLL / structure_knowledge_distillation

The official code for the paper 'Structured Knowledge Distillation for Semantic Segmentation' (CVPR 2019 oral) and its extension to other tasks.
BSD 2-Clause "Simplified" License

How to implement the pair-wise loss in TensorFlow #1

Closed cs-heibao closed 5 years ago

cs-heibao commented 5 years ago

How do I implement the pair-wise distillation (pair-wise loss) in TensorFlow, especially the computation of a_ij? Thanks!

irfanICMLL commented 5 years ago

You can refer to the class `Cos_Attn` in utils/criterion:

```python
import torch
import torch.nn as nn


class Cos_Attn(nn.Module):
    """Cosine-similarity self-attention layer used for the pair-wise loss."""

    def __init__(self, activation):
        super(Cos_Attn, self).__init__()
        self.activation = activation
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs:
            x : input feature maps (B x C x H x W)
        returns:
            attention : B x N x N pair-wise similarity map, N = H * W
        """
        m_batchsize, C, height, width = x.size()
        proj_query = x.view(m_batchsize, -1, height * width).permute(0, 2, 1)  # B x N x C
        proj_key = x.view(m_batchsize, -1, height * width)  # B x C x N
        q_norm = proj_query.norm(2, dim=2)  # per-pixel L2 norm, B x N
        # Outer product of the norms: nm[b, i, j] = ||f_i|| * ||f_j||
        nm = torch.bmm(q_norm.view(m_batchsize, height * width, 1),
                       q_norm.view(m_batchsize, 1, height * width))
        energy = torch.bmm(proj_query, proj_key)  # raw dot products, B x N x N
        norm_energy = energy / nm  # cosine similarity a_ij
        attention = self.softmax(norm_energy)  # B x N x N
        return attention
```
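A minimal usage sketch (the shapes below are illustrative, not from the repo):

```python
attn = Cos_Attn(activation='softmax')
feat = torch.randn(2, 64, 32, 32)  # B x C x H x W feature maps
a = attn(feat)                     # 2 x 1024 x 1024 pair-wise similarity map
```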

You can replace the related APIs with their TensorFlow equivalents to compute the a_ij.
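For reference, here is a minimal TensorFlow 2 sketch of the same a_ij computation, assuming channels-last (B, H, W, C) feature maps; the function name `pairwise_similarity` is illustrative, not part of this repo:

```python
import tensorflow as tf


def pairwise_similarity(x):
    """Return the B x N x N cosine-similarity matrix a_ij, with N = H * W."""
    shape = tf.shape(x)
    b, h, w = shape[0], shape[1], shape[2]
    feats = tf.reshape(x, [b, h * w, -1])               # B x N x C
    norms = tf.norm(feats, axis=2)                      # per-pixel L2 norms, B x N
    nm = norms[:, :, None] * norms[:, None, :]          # outer product of norms, B x N x N
    energy = tf.matmul(feats, feats, transpose_b=True)  # raw dot products, B x N x N
    return energy / nm                                  # cosine similarity a_ij
```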

steermomo commented 5 years ago

Hi, could you explain why softmax is used here? `norm_energy` should be the similarity matrix. What is the attention here? It does not seem to be mentioned in the paper. Thanks. :)

irfanICMLL commented 5 years ago

There are three versions in utils/criterion: 'softmax', 'sigmoid', and 'no activation'. The activation is only used to normalise the scale of the loss; you can use the `norm_energy` directly. We found the three variants perform similarly after adjusting the loss weight.
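For the 'no activation' variant, a minimal PyTorch sketch of the full pair-wise loss, i.e. the squared difference between the student's and teacher's similarity maps (`pairwise_loss` and `similarity` are illustrative names, not from the repo):

```python
import torch
import torch.nn.functional as F


def pairwise_loss(feat_s, feat_t):
    """Mean of (a_ij^S - a_ij^T)^2 over all N x N pixel pairs, N = H * W."""
    def similarity(x):
        b, c, h, w = x.size()
        f = x.view(b, c, h * w).permute(0, 2, 1)  # B x N x C
        f = F.normalize(f, p=2, dim=2)            # unit-normalise each pixel feature
        return torch.bmm(f, f.permute(0, 2, 1))   # cosine similarity a_ij, B x N x N

    # mse_loss averages over the batch and all N^2 entries,
    # matching the 1/N^2 normalisation up to the batch average.
    return F.mse_loss(similarity(feat_s), similarity(feat_t))
```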

zhLawliet commented 4 years ago

@cs-heibao Hi, have you implemented the pair-wise distillation (pair-wise loss) in TensorFlow? Could you share it with me? Thanks.