CVMI-Lab / SimGCD

(ICCV 2023) Parametric Classification for Generalized Category Discovery: A Baseline Study
https://arxiv.org/abs/2211.11727
MIT License

Loss functions do not match the paper? #11

Closed: grypesc closed this issue 9 months ago

grypesc commented 9 months ago

Hi, in your paper you state that you use two loss functions to train the feature extractor (SupCL and SelfSupCL). But in your code I can see four losses, including a cross-entropy loss between the targets and the outputs of the feature extractor:

train.py:

                # clustering, sup
                sup_logits = torch.cat([f[mask_lab] for f in (student_out / 0.1).chunk(2)], dim=0)
                sup_labels = torch.cat([class_labels[mask_lab] for _ in range(2)], dim=0)
                cls_loss = nn.CrossEntropyLoss()(sup_logits, sup_labels)

                # clustering, unsup
                cluster_loss = cluster_criterion(student_out, teacher_out, epoch)
                avg_probs = (student_out / 0.1).softmax(dim=1).mean(dim=0)
                me_max_loss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs)))
                cluster_loss += args.memax_weight * me_max_loss

                # representation learning, unsup
                contrastive_logits, contrastive_labels = info_nce_logits(features=student_proj)
                contrastive_loss = torch.nn.CrossEntropyLoss()(contrastive_logits, contrastive_labels)

                # representation learning, sup
                student_proj = torch.cat([f[mask_lab].unsqueeze(1) for f in student_proj.chunk(2)], dim=1)
                student_proj = torch.nn.functional.normalize(student_proj, dim=-1)
                sup_con_labels = class_labels[mask_lab]
                sup_con_loss = SupConLoss()(student_proj, labels=sup_con_labels)

This loss also seems to affect the feature extractor. Can you elaborate?
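Whether a loss computed on the head's output also updates the feature extractor depends on whether the features are detached before they reach the head. The following toy sketch (PyTorch) shows that a temperature-scaled cross-entropy like `cls_loss` sends gradients into the backbone unless the features are explicitly detached. The `backbone`, `head`, and `backbone_grad_after_ce` names and all shapes are hypothetical stand-ins, not code from SimGCD.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy stand-ins: a linear "backbone" that produces features and a
# linear head that turns them into class logits (the role `student_out` plays).
backbone = nn.Linear(16, 8)
head = nn.Linear(8, 5, bias=False)

x = torch.randn(4, 16)               # 4 labelled inputs
labels = torch.tensor([0, 1, 2, 3])  # their ground-truth classes


def backbone_grad_after_ce(detach_features: bool):
    """Backpropagate a temperature-scaled cross-entropy loss and return the
    gradient that reaches the backbone's weight (None if nothing reaches it)."""
    backbone.zero_grad(set_to_none=True)
    head.zero_grad(set_to_none=True)
    feats = backbone(x)
    if detach_features:
        feats = feats.detach()       # cut the graph between head and backbone
    logits = head(feats) / 0.1       # temperature 0.1, as in the snippet
    loss = nn.CrossEntropyLoss()(logits, labels)
    loss.backward()
    return backbone.weight.grad


print(backbone_grad_after_ce(detach_features=False) is not None)  # gradient reaches the backbone
print(backbone_grad_after_ce(detach_features=True) is None)       # only the head is updated
```

Assuming `student_out` is computed from the backbone features without a `.detach()` (the forward pass is not shown in the snippet), the same reasoning means `cls_loss` updates the backbone as well as the head.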

DTennant commented 9 months ago

Hi, our method SimGCD is introduced in Section 4 of our paper; Sections 4.1 and 4.2 describe the four losses we use to train the model.
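The `me_max_loss` regularizer that the snippet adds to `cluster_loss` is written compactly. Because log(p^(-p)) = -p log p, it equals log K minus the entropy of the batch-averaged prediction, where K is the number of classes. The following sketch (PyTorch; both function names are hypothetical) checks that equivalence. The term is 0 when the average prediction is uniform and approaches log K when every sample is pushed to the same class, so minimizing it discourages the batch from collapsing onto a few clusters.

```python
import math
import torch


def me_max_as_in_snippet(student_out: torch.Tensor, temp: float = 0.1) -> torch.Tensor:
    # Expression taken from the snippet: -sum(log(p ** -p)) + log K
    avg_probs = (student_out / temp).softmax(dim=1).mean(dim=0)
    return -torch.sum(torch.log(avg_probs ** (-avg_probs))) + math.log(float(len(avg_probs)))


def me_max_via_entropy(student_out: torch.Tensor, temp: float = 0.1) -> torch.Tensor:
    # The same quantity written as log K - H(mean prediction)
    avg_probs = (student_out / temp).softmax(dim=1).mean(dim=0)
    entropy = -(avg_probs * avg_probs.log()).sum()
    return math.log(avg_probs.numel()) - entropy


torch.manual_seed(0)
logits = torch.randn(8, 5)
print(me_max_as_in_snippet(logits).item(), me_max_via_entropy(logits).item())

uniform = torch.zeros(8, 5)                         # every row predicts all classes equally
collapsed = torch.tensor([[1.0, 0, 0, 0, 0]] * 8)   # every row favours class 0
print(me_max_via_entropy(uniform).item())           # approximately 0
print(me_max_via_entropy(collapsed).item(), math.log(5))
```

The two forms agree numerically. One plausible advantage of the snippet's `p ** (-p)` form is that an averaged probability that underflows to exactly 0 gives 0 ** 0 = 1 and therefore log 1 = 0, whereas the explicit entropy form would evaluate 0 * log 0 and produce NaN.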

grypesc commented 9 months ago

Thanks for the clarification; I misunderstood that part.