ManiadisG / ExCB


I have some issues about centering #2

Open LYJhere opened 6 days ago

LYJhere commented 6 days ago

I have some issues with the following code:

@torch.no_grad()
def centering(self, teacher_output):
    labels_pre = teacher_output.argmax(-1)
    lower_w, higher_w = self.centering_modifiers()
    teacher_output = 1-(-teacher_output+1)*lower_w # Decreasing the cosine DISTANCE for small clusters
    teacher_output = (teacher_output+1)*higher_w - 1 # Decreasing the cosine SIMILARITY for large clusters
    teacher_output_argmax = teacher_output.argmax(-1)
    teacher_output_argmax_oh = F.one_hot(teacher_output_argmax, self.out_dim)

    # boosting choices
    bt = torch.ones(teacher_output.shape).type_as(teacher_output)+teacher_output_argmax_oh*0.01
    teacher_output = teacher_output * bt
    teacher_output = torch.clamp(teacher_output, -1, 1)
    labels_post = teacher_output.argmax(-1)
    changed_labels = 1-(labels_post==labels_pre).float().mean()
    return teacher_output, changed_labels, {"llb_min": lower_w.min(), "lhb_min": higher_w.min()}

def centering_modifiers(self):
    center = F.normalize(self.center,p=1,dim=-1)*self.out_dim
    lower_w = 1 - F.relu(1-center)
    lower_w += torch.rand(lower_w.shape, device=self.center.device)*0.001*(lower_w==0).float() # To prevent from going to 0
    higher_w = 1 - F.relu(1-1/center)
    higher_w += torch.rand(higher_w.shape, device=self.center.device)*0.001*(higher_w==0).float() # To prevent from going to 0
    return lower_w, higher_w

Does lower_w in the code correspond to s_k, and higher_w to 1/s_k?


I cannot find the comparison between s and 1/K in the code that is depicted in your paper.

ManiadisG commented 6 days ago

Hello.

The operation in the code is, in practice, the same as the one described in the paper, but the implementation differs for efficiency reasons. In effect, the two cases (s > 1/K and s < 1/K) are handled by the two different weight vectors, as you note in your comment. The comparison itself is "represented" by the relu operators.
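
To illustrate, here is a minimal sketch with a made-up center vector (not the repository code verbatim; it omits the small random term that keeps the weights away from exactly 0). After the L1-normalization and scaling by out_dim, a cluster's entry is below 1 exactly when s < 1/K and above 1 exactly when s > 1/K, so each relu zeroes out the modification for the clusters where its case does not apply:

    import torch
    import torch.nn.functional as F

    # Hypothetical running cluster sizes s_k for K = out_dim = 4 clusters;
    # the uniform share is 1/K = 0.25.
    center = torch.tensor([0.10, 0.25, 0.40, 0.25])
    out_dim = center.numel()

    # Same normalization as in centering_modifiers: entries become s_k * K,
    # i.e. < 1 for small clusters (s_k < 1/K) and > 1 for large ones (s_k > 1/K).
    c = F.normalize(center, p=1, dim=-1) * out_dim  # -> [0.4, 1.0, 1.6, 1.0]

    # relu(1 - c) is nonzero only where c < 1, so lower_w drops below 1
    # only for small clusters and stays 1 (no change) everywhere else.
    lower_w = 1 - F.relu(1 - c)  # -> [0.4, 1.0, 1.0, 1.0]

    # relu(1 - 1/c) is nonzero only where c > 1, so higher_w drops below 1
    # only for large clusters and stays 1 (no change) everywhere else.
    higher_w = 1 - F.relu(1 - 1 / c)  # -> [1.0, 1.0, 0.625, 1.0]

    print(lower_w, higher_w)

In this toy case the first (small) cluster has its cosine distance shrunk through lower_w, and the third (large) cluster has its cosine similarity shrunk through higher_w; that is where the s vs 1/K comparison happens implicitly.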