kekmodel / FixMatch-pytorch

Unofficial PyTorch implementation of "FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence"
MIT License

Unsupervised loss part for one single class #11

Closed: dendrobiumz closed this issue 4 years ago

dendrobiumz commented 4 years ago

I'm trying to apply FixMatch to single-class data. For the unsupervised loss, I modified the code like this:

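```python
logits_u_w, logits_u_s = logits[batch_size:].chunk(2)  # weak / strong views of the unlabeled batch
pseudo_label = torch.sigmoid(logits_u_w.detach_())     # probability of the single class
mask = pseudo_label.ge(args.threshold).float()         # 1.0 where p >= threshold, else 0.0
Lu = (F.binary_cross_entropy(logits_u_s, mask, reduction='none') * mask).mean()
```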

Is that correct?
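A quick way to see what the loss above does is to run it on a few fixed values. The sketch below uses made-up logits for four unlabeled samples and assumes `args.threshold = 0.95`. It applies `torch.sigmoid` to the strong-view logits first, because `F.binary_cross_entropy` expects probabilities in [0, 1] and PyTorch rejects raw logits outside that range. It also shows that, with `mask` used as both the target and the weight, every sample that contributes gets target 1, so a confidently negative prediction is masked out rather than pushed toward 0.

```python
import torch
import torch.nn.functional as F

threshold = 0.95  # assumed value of args.threshold
# Hypothetical weak/strong logits for four unlabeled samples (one output each).
logits_u_w = torch.tensor([[4.0], [-4.0], [0.2], [3.5]])
logits_u_s = torch.tensor([[1.0], [-1.0], [3.0], [2.0]])

pseudo_label = torch.sigmoid(logits_u_w)
mask = pseudo_label.ge(threshold).float()

# F.binary_cross_entropy needs probabilities, so squash the strong logits first.
per_sample = F.binary_cross_entropy(torch.sigmoid(logits_u_s), mask, reduction="none")
# Equivalent and numerically more stable: the sigmoid is applied inside the loss.
same = F.binary_cross_entropy_with_logits(logits_u_s, mask, reduction="none")
Lu = (per_sample * mask).mean()

# Only the two confident positives pass; sigmoid(-4) is about 0.018, so the
# confident negative is dropped instead of being trained toward target 0.
print(mask.flatten().tolist())  # [1.0, 0.0, 0.0, 1.0]
```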

kekmodel commented 4 years ago

```python
targets_u = pseudo_label.ge(0.5)  # hard pseudo-labels: True where p >= 0.5
Lu = (F.binary_cross_entropy(logits_u_s, targets_u, reduction='none') * mask).mean()  # mask still selects which samples count
```
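Below is a self-contained sketch of a single-class unsupervised loss built around the reply's idea of hard targets from `pseudo_label.ge(0.5)`. Two changes are our assumptions, not code from the repository. First, the strong-view logits go through `F.binary_cross_entropy_with_logits`, since plain `F.binary_cross_entropy` expects probabilities and a floating-point target (raw logits and the boolean `targets_u` would both be rejected). Second, the mask keeps predictions that are confident in either direction, `max(p, 1 - p) >= threshold`, so confident negatives also contribute. The function name `single_class_unsup_loss` is hypothetical; the default `threshold=0.95` follows the confidence threshold used in the FixMatch paper.

```python
import torch
import torch.nn.functional as F


def single_class_unsup_loss(logits_u_w, logits_u_s, threshold=0.95):
    """Hypothetical FixMatch-style unsupervised loss for one sigmoid output.

    logits_u_w / logits_u_s: raw logits for the weak / strong views, same shape.
    """
    # Pseudo-labels come from the weak view; no gradient flows through them.
    probs_u_w = torch.sigmoid(logits_u_w.detach())
    # Hard targets as floats, since BCE targets must be floating point.
    targets_u = probs_u_w.ge(0.5).float()
    # Confidence in the predicted class: p for positives, 1 - p for negatives.
    confidence = torch.maximum(probs_u_w, 1.0 - probs_u_w)
    mask = confidence.ge(threshold).float()
    # Per-sample BCE on the strong view's logits (sigmoid applied internally).
    loss = F.binary_cross_entropy_with_logits(logits_u_s, targets_u, reduction="none")
    # As in FixMatch, average over all unlabeled samples; masked ones count as 0.
    return (loss * mask).mean()


logits_u_w = torch.tensor([[4.0], [-4.0], [0.2]])
logits_u_s = torch.tensor([[1.0], [-1.0], [3.0]])
print(single_class_unsup_loss(logits_u_w, logits_u_s))
```

With `threshold=0.95`, a weak-view logit of 4 (p about 0.982) yields target 1 and a logit of -4 yields target 0, and both are kept; a logit of 0.2 (p about 0.55) is masked out.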