Closed by dendrobiumz 4 years ago.
I am trying to apply FixMatch to one-class data (a single sigmoid output). For the unsupervised loss, I modified the code as follows:
logits_u_w, logits_u_s = logits[batch_size:].chunk(2)
pseudo_label = torch.sigmoid(logits_u_w.detach_())
mask = pseudo_label.ge(args.threshold).float()
Lu = (F.binary_cross_entropy(logits_u_s, mask, reduction='none') * mask).mean()
Is that correct?
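One way to see what the loss above optimizes: the target is `mask` and each per-sample term is also multiplied by `mask`, so every term that survives has target 1 and confident negatives never contribute. Note also that `F.binary_cross_entropy` expects probabilities, not raw logits, so `logits_u_s` needs a sigmoid first. The sketch below checks this on toy tensors chosen for illustration; `question_loss` is a hypothetical helper that wraps the question's code with that sigmoid added.

```python
import torch
import torch.nn.functional as F

def question_loss(logits_u_w, logits_u_s, threshold):
    """Hypothetical wrapper around the loss from the question.

    A sigmoid is applied to the strong-view logits so that
    F.binary_cross_entropy receives values in [0, 1].
    """
    pseudo_label = torch.sigmoid(logits_u_w.detach())
    mask = pseudo_label.ge(threshold).float()
    # The target is the mask itself, so any sample that passes the
    # threshold is always trained toward label 1.
    per_sample = F.binary_cross_entropy(torch.sigmoid(logits_u_s), mask, reduction='none')
    return per_sample * mask, mask

# Toy logits: sample 0 is a confident positive, sample 1 a confident
# negative, sample 2 is uncertain.
logits_u_w = torch.tensor([4.0, -4.0, 0.1])
logits_u_s = torch.tensor([2.0, -2.0, 0.0])
terms, mask = question_loss(logits_u_w, logits_u_s, threshold=0.95)
# mask is [1, 0, 0]: the confident negative (sample 1) is dropped,
# and only sample 0 contributes, with target 1.
```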
targets_u = pseudo_label.ge(0.5).float()  # hard 0/1 pseudo-labels; BCE needs float targets
Lu = (F.binary_cross_entropy_with_logits(logits_u_s, targets_u, reduction='none') * mask).mean()  # logits_u_s are raw logits
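Using hard labels `p >= 0.5` as targets only helps if the mask also keeps confident negatives: with `mask = pseudo_label.ge(threshold)` and a threshold above 0.5, every kept sample already has target 1. For a single-logit classifier, one plausible analogue of FixMatch's `max(softmax) >= threshold` test is to use `max(p, 1 - p)` as the confidence. The sketch below assumes that choice; the function name `binary_fixmatch_unsup_loss` is hypothetical, and it uses `F.binary_cross_entropy_with_logits` so the strong-view logits can be passed unchanged.

```python
import torch
import torch.nn.functional as F

def binary_fixmatch_unsup_loss(logits_u_w, logits_u_s, threshold=0.95):
    """Hypothetical binary (single-logit) variant of FixMatch's unlabeled loss.

    logits_u_w: raw logits for the weakly augmented batch, shape (N,)
    logits_u_s: raw logits for the strongly augmented batch, shape (N,)
    """
    # Positive-class probabilities from the weak view; no gradient flows here.
    probs = torch.sigmoid(logits_u_w.detach())
    # Hard pseudo-labels: 1 if the weak view predicts positive, else 0.
    targets_u = probs.ge(0.5).float()
    # Confidence in the predicted class, so confident negatives are kept too.
    confidence = torch.maximum(probs, 1.0 - probs)
    mask = confidence.ge(threshold).float()
    # BCE on raw logits; masked-out samples contribute zero, and the mean
    # is taken over the whole unlabeled batch.
    per_sample = F.binary_cross_entropy_with_logits(logits_u_s, targets_u, reduction='none')
    return (per_sample * mask).mean(), mask

# Toy logits: confident positive, confident negative, uncertain.
logits_u_w = torch.tensor([4.0, -4.0, 0.1])
logits_u_s = torch.tensor([2.0, -2.0, 0.0])
Lu, mask = binary_fixmatch_unsup_loss(logits_u_w, logits_u_s, threshold=0.95)
# mask is [1, 1, 0]: the confident negative now contributes with target 0.
```

Averaging over the full batch rather than over the kept samples mirrors the `.mean()` in the question's code; dividing by `mask.sum()` instead would be a different design choice.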