Haochen-Wang409 / U2PL

[CVPR'22] Semi-Supervised Semantic Segmentation Using Unreliable Pseudo-Labels

Possible issue with `compute_contra_memobank_loss` #171

Open lucasdavid opened 2 months ago

lucasdavid commented 2 months ago

Hello! Thank you for the great paper and code. It has been really helpful to me! I believe there might be an issue with the function `compute_contra_memobank_loss`, and I'd appreciate it if you could clarify it for me. The paper states:

"For i-th labeled image, a qualified negative sample for class c should be: (a) not belonging to class c; (b) difficult to distinguish between class c and its ground-truth category."

However, `compute_contra_memobank_loss` does something different:

```python
# one-hot (pseudo-)labels restricted to high-entropy pixels
high_valid_pixel = torch.cat((label_l, label_u), dim=0) * high_mask

for i in range(num_segments):
  high_valid_pixel_seg = high_valid_pixel[:, i]  # high-entropy pixels whose label is class i
  rep_mask_high_entropy = (prob_seg < 1.0) * high_valid_pixel_seg.bool()

  # unlabeled pixels where class i is ranked in [low_rank, high_rank)
  class_mask_u = torch.sum(
    prob_indices_u[:, :, :, low_rank:high_rank].eq(i), dim=3
  ).bool()
  # labeled pixels where class i is ranked among the top low_rank classes
  class_mask_l = torch.sum(prob_indices_l[:, :, :, :low_rank].eq(i), dim=3).bool()

  class_mask = torch.cat((class_mask_l * (label_l[:, i] == 0), class_mask_u), dim=0)
  negative_mask = rep_mask_high_entropy * class_mask

  keys = rep_teacher[negative_mask].detach()
  new_keys.append(dequeue_and_enqueue(keys=keys, ...))
```

For the labeled samples, negative_mask is the conjunction of label_l[:, i] == 0 (from class_mask) and label_l[:, i] == 1 (from high_valid_pixel), so it is always False. (The (prob_seg < 1.0) factor is dropped in the derivation below; it can only shrink the mask further, so the conclusion still holds.)

```
negative_mask_l = rep_mask_high_entropy[:NL] * (class_mask_l * (label_l[:, i] == 0))
                = (label_l * high_mask[:NL])[:, i] * (class_mask_l * (label_l[:, i] == 0))
                = (label_l * (label_l[:, i] == 0) * high_mask[:NL])[:, i] * class_mask_l
                = (0 * high_mask[:NL])[:, i] * class_mask_l
                = 0
```
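
As a sanity check, here is a minimal toy sketch of the labeled branch. The shapes, the class index, the all-ones high_mask, and the all-True class_mask_l are made-up assumptions chosen to be as permissive as possible, so any non-empty mask would show up:

```python
import torch

# Toy setup: 1 labeled image, C=4 classes, 8x8 pixels, class of interest i=2 (all hypothetical).
NL, C, H, W, i = 1, 4, 8, 8, 2
label_l = torch.nn.functional.one_hot(
    torch.randint(C, (NL, H, W)), C
).permute(0, 3, 1, 2).float()          # one-hot ground truth, shape (NL, C, H, W)
high_mask = torch.ones(NL, 1, H, W)    # assume every pixel counts as "high entropy"
prob_seg = torch.rand(NL, H, W)        # stand-in for the class-i probability, always < 1.0

# Mirrors the labeled half of compute_contra_memobank_loss.
high_valid_pixel_seg = (label_l * high_mask)[:, i]             # requires label_l[:, i] == 1
rep_mask_high_entropy = (prob_seg < 1.0) * high_valid_pixel_seg.bool()
class_mask_l = torch.ones(NL, H, W, dtype=torch.bool)          # most permissive case
negative_mask_l = rep_mask_high_entropy * (class_mask_l * (label_l[:, i] == 0))

print(negative_mask_l.any())  # tensor(False): the two label_l[:, i] conditions conflict
```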

For unlabeled samples, negative_mask will be true when both high_valid_pixel and class_mask_u hold: the teacher assigns the pixel to class i with high entropy (high_valid_pixel), while the student ranks class i between low_rank and high_rank, i.e. it is not among its top predictions (prob_indices_u[:, :, :, low_rank:high_rank].eq(i)).
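
By contrast, the same kind of toy sketch for the unlabeled branch shows that its mask can be non-empty. The tensors, the low_rank/high_rank values, and the hand-built prob_indices_u below are hypothetical; the point is only to exercise the Boolean logic of the quoted code:

```python
import torch

# Toy setup for the unlabeled half (hypothetical shapes and ranks).
NU, C, H, W, i = 1, 4, 8, 8, 2
low_rank, high_rank = 1, 3

label_u = torch.zeros(NU, C, H, W)
label_u[:, i] = 1.0                      # pseudo-label says class i everywhere
high_mask = torch.ones(NU, 1, H, W)      # assume every pixel counts as "high entropy"
prob_seg = torch.rand(NU, H, W)

# Classes sorted by predicted probability; put class i at rank 1 for every pixel,
# so it falls inside [low_rank, high_rank).
prob_indices_u = torch.tensor([0, i, 1, 3]).view(1, 1, 1, C).expand(NU, H, W, C)

high_valid_pixel_seg = (label_u * high_mask)[:, i]
rep_mask_high_entropy = (prob_seg < 1.0) * high_valid_pixel_seg.bool()
class_mask_u = torch.sum(
    prob_indices_u[:, :, :, low_rank:high_rank].eq(i), dim=3
).bool()
negative_mask_u = rep_mask_high_entropy * class_mask_u

print(negative_mask_u.any())  # tensor(True): unlabeled pixels can enter the memory bank
```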

Therefore, we believe all negative candidates from labeled images are being discarded and never enter the memory bank. Does that make sense?

Cheers,