naraysa / 3c-net

Weakly-supervised Action Localization

about the center-loss #6

Closed (wanboyang closed 2 years ago)

wanboyang commented 4 years ago

I have a question about the check if (labels[i] > 0).sum() == 0 or ((labels[i] > 0).sum() != 1 and itr < itr_th): continue in the following function:

def CENTERLOSS(features, logits, labels, seq_len, criterion, itr, device):
    lab = torch.zeros(0).to(device)
    feat = torch.zeros(0).to(device)
    itr_th = 5000    
    for i in range(features.size(0)):
        if (labels[i] > 0).sum() == 0 or ((labels[i] > 0).sum() != 1 and itr < itr_th):
            continue
        # categories present in the video
        labi = torch.arange(labels.size(1))[labels[i]>0]
        atn = F.softmax(logits[i][:seq_len[i]], dim=0)
        atni = atn[:,labi]
        # aggregate features category-wise
        for l in range(len(labi)):
            labl = labi[[l]].float()
            atnl = atni[:,[l]]
            atnl[atnl<atnl.mean()] = 0
            sum_atn = atnl.sum()
            if sum_atn > 0:
                atnl = atnl.expand(seq_len[i],features.size(2))
                # attention-weighted feature aggregation
                featl = torch.sum(features[i][:seq_len[i]]*atnl,dim=0,keepdim=True)/sum_atn
                feat = torch.cat([feat, featl], dim=0)
                lab = torch.cat([lab, labl], dim=0)

    if feat.numel() > 0:
        # Compute loss
        loss = criterion(feat, lab)
        return loss / feat.size(0)
    else:
        return 0

Does this mean the center loss is applied to multi-label training videos?

naraysa commented 4 years ago

Yes. That check handles the multi-label training videos used for the center loss.

wanboyang commented 4 years ago

"if (labels[i] > 0).sum() == 0 or ((labels[i] > 0).sum() != 1 and itr < itr_th): continue" As I understand it, this means the center loss is applied only to single-label training videos until itr reaches itr_th (itr_th=5000 in this paper). Do you recall what happened in the experiments with itr_th=0?

naraysa commented 4 years ago

Yes, we introduce the multi-label videos only after a while, so that the centers are easier to learn at the beginning of training. In image classification, center loss is generally introduced after the classifier has been learned to a reasonable extent. That is why we bring in the multi-label videos after some iterations rather than at the start. I don't remember the exact numbers for itr_th=0, but it did cause a performance drop.
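
For anyone reading along, the skip condition discussed above can be sketched in isolation. This is only an illustration, not code from the repository: the helper name include_in_center_loss is hypothetical, and it uses plain Python lists instead of tensors so it runs without PyTorch. It assumes, as in CENTERLOSS, that a class is present when its label value is greater than 0.

```python
def include_in_center_loss(label_row, itr, itr_th=5000):
    """Return True if a video would contribute to the center loss at iteration itr.

    Hypothetical helper mirroring the skip condition in CENTERLOSS,
    written for a plain list of per-class labels.
    """
    num_labels = sum(1 for v in label_row if v > 0)
    if num_labels == 0:
        return False  # video with no positive label: always skipped
    if num_labels != 1 and itr < itr_th:
        return False  # multi-label video before itr_th: skipped
    return True


single = [0, 1, 0, 0]  # one category present
multi = [1, 0, 1, 0]   # two categories present
empty = [0, 0, 0, 0]   # no category present

for itr in (0, 4999, 5000):
    print(
        itr,
        include_in_center_loss(single, itr),
        include_in_center_loss(multi, itr),
        include_in_center_loss(empty, itr),
    )
```

Under these assumptions, single-label videos are used from the first iteration, multi-label videos only once itr reaches itr_th, and videos without labels never. Passing itr_th=0 admits multi-label videos from the start, which is the setting reported above to reduce performance.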