wbw520 / BotCL

Learning Bottleneck Concepts in Image Classification (CVPR 2023)

Implementation question about the contrastive loss #4

Open Adasunnylily opened 10 months ago

Adasunnylily commented 10 months ago

Hi, I am confused about how the contrastive loss in the paper is computed. The paper says lret is calculated from (t, t', y, y'), but in the code the model returns (cpt - 0.5) * 2, cls, attn, update, and the loss seems to be computed by directly feeding (cpt - 0.5) * 2 as y and using pairwise_loss as lret. However, the call in the released code is pairwise_loss(y, y, label, label), so the two inputs are always identical, which is really confusing. Could you please tell me if there is anything wrong with my understanding? I really appreciate your help!

def get_retrieval_loss(args, y, label, num_cls, device):
    b = label.shape[0]
    if args.dataset != "matplot":
        # Convert integer class labels to one-hot vectors of shape (b, num_cls)
        label = label.unsqueeze(-1)
        label = torch.zeros(b, num_cls).to(device).scatter(1, label, 1)
    # Pairwise similarity loss over the batch: y is compared against itself,
    # so every same-class pair is pushed toward similar concept activations
    similarity_loss = pairwise_loss(y, y, label, label, sigmoid_param=10. / 32)
    # similarity_loss = pairwise_loss2(y, y, label.float(), sigmoid_param=10. / 32)
    # Quantization loss pushes concept activations toward discrete values
    q_loss = quantization_loss(y)
    return similarity_loss, q_loss
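For context, a quantization loss of this kind typically penalizes concept activations that sit between the discrete endpoints. A minimal sketch, assuming y = (cpt - 0.5) * 2 maps cpt in [0, 1] onto [-1, 1] (the function name and exact form here are illustrative assumptions, not the repo's implementation):

```python
import torch

def quantization_loss_sketch(y):
    # Hypothetical sketch: penalize activations far from -1 or +1,
    # assuming y = (cpt - 0.5) * 2 lies in [-1, 1]
    return torch.mean((y.abs() - 1.0) ** 2)
```

With this form the loss is zero only when every activation is exactly -1 or +1, so it encourages near-binary concept presence/absence.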
wbw520 commented 10 months ago

Thanks for the question. We calculate the similarity between the concepts (y) and the corresponding class labels (label) within a batch of data. It is therefore a matrix of similarity scores, which forces image pairs from the same class to share a similar concept distribution.
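To illustrate the mechanism described above, here is a minimal sketch of a HashNet-style pairwise loss that builds a batch-level similarity matrix; calling it with (y, y, label, label) compares every sample against every other sample in the batch. The function name and its exact form are assumptions for illustration, not the repo's implementation:

```python
import torch

def pairwise_loss_sketch(y, label, sigmoid_param=10. / 32):
    # Hypothetical sketch of a pairwise logistic (HashNet-style) loss.
    # label: one-hot (B, num_cls); sim[i, j] = 1 if samples i and j share a class
    sim = (label @ label.t() > 0).float()
    # Scaled inner-product similarity between concept activations: (B, B)
    dot = sigmoid_param * (y @ y.t())
    # Numerically stable log(1 + exp(dot)) - sim * dot
    loss = torch.log1p(torch.exp(-dot.abs())) + dot.clamp(min=0) - sim * dot
    return loss.mean()
```

Under this form, same-class pairs (sim = 1) are rewarded for a large positive inner product between their concept vectors, while different-class pairs (sim = 0) are penalized for it, which matches the intuition that same-class images should share a similar concept distribution.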