CVMI-Lab / SimGCD

(ICCV 2023) Parametric Classification for Generalized Category Discovery: A Baseline Study
https://arxiv.org/abs/2211.11727
MIT License

Eq 3, Where is pi computed? #13

Closed szalata closed 4 months ago

szalata commented 4 months ago

First of all: very interesting work!

I looked through the code, but I can't find where the implementation computes formula (3) from the paper, that is, pi. Specifically, I'd expect to see a dot product between the normalized prototypes and the hidden features.

I'd expect it in this part, before the cross entropy:

                student_proj, student_out = student(images)
                teacher_out = student_out.detach()

                # clustering, sup
                sup_logits = torch.cat([f[mask_lab] for f in (student_out / 0.1).chunk(2)], dim=0)
                sup_labels = torch.cat([class_labels[mask_lab] for _ in range(2)], dim=0)
                cls_loss = nn.CrossEntropyLoss()(sup_logits, sup_labels)

                # clustering, unsup
                cluster_loss = cluster_criterion(student_out, teacher_out, epoch)
                avg_probs = (student_out / 0.1).softmax(dim=1).mean(dim=0)
                me_max_loss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs)))
                cluster_loss += args.memax_weight * me_max_loss
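(For anyone else reading along: the `me_max_loss` term above is the negative entropy of the mean prediction plus log K, so it is zero for a uniform mean prediction and positive when predictions collapse onto fewer clusters. A minimal numeric check, with hypothetical probability values:)

```python
import math
import torch

# Uniform mean prediction over K=4 clusters: entropy equals log(K),
# so the me-max regularizer -H(avg_probs) + log(K) is exactly zero.
avg_probs = torch.tensor([0.25, 0.25, 0.25, 0.25])
me_max = -torch.sum(torch.log(avg_probs ** (-avg_probs))) + math.log(len(avg_probs))
assert torch.isclose(me_max, torch.tensor(0.0), atol=1e-6)

# A skewed mean prediction has lower entropy, so the penalty is positive,
# discouraging collapse onto a few clusters.
skewed = torch.tensor([0.85, 0.05, 0.05, 0.05])
me_max_skewed = -torch.sum(torch.log(skewed ** (-skewed))) + math.log(len(skewed))
assert me_max_skewed > 0
```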

Instead, it seems like the logits are directly computed from the base network:

    def forward(self, x):
        x_proj = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        # x = x.detach()
        logits = self.last_layer(x)
        return x_proj, logits

I'd very much appreciate a clarification.

xwen99 commented 4 months ago

Hi, the logits are exactly what you are looking for. The `last_layer` plays the role of the prototypes: its weight rows are the prototype vectors, so applying it to the normalized features computes the dot products in Eq. (3).
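(To make the equivalence concrete, here is a minimal sketch with hypothetical dimensions and random values: a bias-free linear layer whose weight rows are L2-normalized prototype vectors, applied to L2-normalized features, produces exactly the prototype–feature dot products.)

```python
import torch
import torch.nn as nn

feat_dim, num_classes = 8, 3
torch.manual_seed(0)

# Hypothetical prototypes: one L2-normalized vector per class.
prototypes = nn.functional.normalize(torch.randn(num_classes, feat_dim), dim=-1)

# A bias-free linear layer whose weight matrix holds the prototypes as rows.
last_layer = nn.Linear(feat_dim, num_classes, bias=False)
with torch.no_grad():
    last_layer.weight.copy_(prototypes)

# Normalized hidden features, as in the forward() shown above.
x = nn.functional.normalize(torch.randn(4, feat_dim), dim=-1)
logits = last_layer(x)

# The logits coincide with the explicit dot products <prototype_k, x>.
assert torch.allclose(logits, x @ prototypes.t(), atol=1e-6)
```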

szalata commented 4 months ago

Thank you, I get it now!