Closed — tom666tom666 closed this issue 2 years ago
Hi @tom666tom666, thank you for your interest in our work.
This is because, for every `reference_style`, the positive sample is always `style_comparisons[0]`. We set `label = 0` so that the positive sample becomes the numerator of the style contrastive loss (given how `torch.nn.CrossEntropyLoss` works), exactly as in Eq. 5 of the paper.
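To make the connection concrete, here is a minimal sketch of this idea (the function name, shapes, and temperature are assumptions for illustration, not the repo's actual implementation): with logits equal to similarities divided by a temperature, `cross_entropy` with target `0` computes `-log(exp(sim_pos/tau) / sum_j exp(sim_j/tau))`, i.e. the positive at index 0 sits in the numerator of the InfoNCE-style ratio.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(reference, comparisons, tau=0.2, index=0):
    # reference:   (1, D) feature of the anchor sample
    # comparisons: (N, D) features; comparisons[index] is the positive
    logits = reference @ comparisons.t() / tau  # (1, N) similarity scores
    target = torch.tensor([index], dtype=torch.long, device=reference.device)
    # cross_entropy(logits, target) = -log( exp(logits[0, index]) / sum_j exp(logits[0, j]) ),
    # so the positive's similarity is the numerator, as in the paper's Eq. 5.
    return F.cross_entropy(logits, target)

# Toy check: the positive (index 0) matches the reference exactly,
# so its logit dominates and the loss is close to zero.
ref = torch.tensor([[1.0, 0.0]])
comps = torch.tensor([[1.0, 0.0],   # positive: identical to ref
                      [0.0, 1.0],   # negative: orthogonal
                      [-1.0, 0.0]]) # negative: opposite
loss = contrastive_loss(ref, comps)  # small loss, positive dominates
```

So the label is not a class identity; it is just the *position* of the positive sample in `comparisons`, which the code fixes at 0 by construction.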
Thanks for your answer.
Good job.
`style_contrastive_loss += self.compute_contrastive_loss(reference_style, style_comparisons, 0.2, 0)`

For example, with `reference_style = [s1c1]` and `style_comparisons = [s1c2, s2c3, s3c4, ...]`, `style_comparisons[0]` is the positive sample for `reference_style`. But during the loss computation, why are all the labels 0? (The code is `loss = self.cross_entropy_loss(out, torch.tensor([index], dtype=torch.long, device=feat_q.device))` with `index = 0`.)