HalbertCH / IEContraAST

This is the official PyTorch implementation of our paper: "Artistic Style Transfer with Internal-external Learning and Contrastive Learning".
MIT License

a question about Content contrastive loss #4

Closed tom666tom666 closed 2 years ago

tom666tom666 commented 2 years ago

Good job.


In `style_contrastive_loss += self.compute_contrastive_loss(reference_style, style_comparisons, 0.2, 0)`, suppose `reference_style = [s1c1]` and `style_comparisons = [s1c2, s2c3, s3c4, ...]`, so `style_comparisons[0]` is the positive sample for `reference_style`. But in the calculation, why are all the labels 0? The code is `loss = self.cross_entropy_loss(out, torch.tensor([index], dtype=torch.long, device=feat_q.device))` with `index = 0`.

HalbertCH commented 2 years ago

Hi @tom666tom666, thank you for your interest in our work.

This is because for every `reference_style`, the positive sample is always `style_comparisons[0]`. We set `label = 0` so that the positive sample appears in the numerator of the style contrastive loss (this follows from how `torch.nn.CrossEntropyLoss` is defined), just like Eq. 5 in the paper.
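To illustrate the point, here is a minimal sketch (not the repo's exact code) of an InfoNCE-style contrastive loss: the query/comparison similarities are treated as logits, and passing target `0` to `CrossEntropyLoss` places the similarity to `feat_k[0]` (the positive) in the numerator of the softmax, with all other comparisons as negatives. The function name and signature here are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def contrastive_loss(feat_q, feat_k, tau=0.2, index=0):
    """InfoNCE-style loss sketch.

    feat_q: query features, shape (1, D).
    feat_k: comparison features, shape (N, D); feat_k[index] is the
            positive sample and the remaining rows are negatives.
    """
    # Cosine similarities between the query and every comparison,
    # scaled by the temperature tau -> logits of shape (1, N).
    out = torch.mm(F.normalize(feat_q, dim=1),
                   F.normalize(feat_k, dim=1).t()) / tau
    # The target class is the index of the positive sample, so
    # CrossEntropyLoss computes -log(exp(sim_pos/tau) / sum_j exp(sim_j/tau)).
    target = torch.tensor([index], dtype=torch.long, device=feat_q.device)
    return nn.CrossEntropyLoss()(out, target)
```

Because the comparisons are ordered so that the positive is always first, the target is the constant `0` rather than a per-sample label.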

tom666tom666 commented 2 years ago

Thanks for your answer.