facebookresearch / PointContrast

Code for the paper "PointContrast: Unsupervised Pretraining for 3D Point Cloud Understanding"
MIT License

Are the negative samples in PointInfoNCE drawn across the batch dimension? #6

Closed liu115 closed 3 years ago

liu115 commented 3 years ago

Thank you for the great work. I really love it.

I have a question about the PointInfoNCE implementation. Are the negative samples in PointInfoNCE drawn across the batch dimension?

The paper defines PointInfoNCE as follows (reproducing the loss here for reference, with τ the temperature hyperparameter):
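$$\mathcal{L}_c = -\sum_{(i,j)\in P} \log \frac{\exp(f^1_i \cdot f^2_j / \tau)}{\sum_{(\cdot,k)\in P} \exp(f^1_i \cdot f^2_k / \tau)}$$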

In this formulation, we only consider points that have at least one match and do not use additional non-matched points as negatives. For a matched pair (i, j) ∈ P, point feature f^1_i will serve as the query and f^2_j will serve as the positive key k+. We use point feature f^2_k where ∃(·, k) ∈ P and k != j as the set of negative keys.

I wonder whether all the pairs in the matched set P come from the two views of the same scene, or from a larger pool. When I read the paper and pseudocode, I assumed the batch dimension is ignored: for a point in x^1, the negative keys all come from x^2 of the same scene, and P is downsampled to 4096 point pairs per mini-batch.

However, in the implementation ddp_trainer.PointNCELossTrainer, it seems like the negative keys are drawn across the batch dimension, so the negative samples for a point in x^1 may come from points in other scenes. Am I correct?

    # Run both augmented views through the shared backbone and take the
    # per-point output features (MinkowskiEngine sparse tensors).
    sinput0 = ME.SparseTensor(
        input_dict['sinput0_F'], coords=input_dict['sinput0_C']).to(self.cur_device)
    F0 = self.model(sinput0).F

    sinput1 = ME.SparseTensor(
        input_dict['sinput1_F'], coords=input_dict['sinput1_C']).to(self.cur_device)
    F1 = self.model(sinput1).F

    N0, N1 = input_dict['pcd0'].shape[0], input_dict['pcd1'].shape[0]
    # (view0_idx, view1_idx) matches, concatenated over the whole per-GPU mini-batch.
    pos_pairs = input_dict['correspondences'].to(self.cur_device)

    # For every unique query index, pick one of its matched keys uniformly at
    # random (the off+cums indexing assumes pos_pairs is grouped by column 0).
    q_unique, count = pos_pairs[:, 0].unique(return_counts=True)
    uniform = torch.distributions.Uniform(0, 1).sample([len(count)]).to(self.cur_device)
    off = torch.floor(uniform*count).long()
    cums = torch.cat([torch.tensor([0], device=self.cur_device), torch.cumsum(count, dim=0)[0:-1]], dim=0)
    k_sel = pos_pairs[:, 1][off+cums]

    q = F0[q_unique.long()]
    k = F1[k_sel.long()]

    # Keep at most self.npos query/key pairs for the loss.
    if self.npos < q.shape[0]:
        sampled_inds = np.random.choice(q.shape[0], self.npos, replace=False)
        q = q[sampled_inds]
        k = k[sampled_inds]
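
For reference, the loss computed right after this snippet looks roughly like the following (paraphrased, not copied verbatim; I take self.T to be the temperature and self.criterion a softmax cross-entropy). Since q and k are stacked over every scene in the per-GPU batch, each off-diagonal key serves as a negative no matter which scene it came from:

    # q, k: [npos, D] feature matrices stacked over ALL scenes in the batch
    logits = torch.mm(q, k.transpose(1, 0)) / self.T  # [npos, npos] similarities
    labels = torch.arange(q.shape[0], device=self.cur_device)
    # Row i: k[i] is the positive; every other k[j], possibly from a
    # different scene, acts as a negative.
    loss = self.criterion(logits, labels)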
s9xie commented 3 years ago

Hi @liu115 - Thank you for your question. You are totally right: the negative keys are sampled across the entire mini-batch (more precisely, only the per-GPU mini-batch in distributed training, since we do not perform all_gather operations across GPUs when computing the loss).

Unfortunately this is not clearly noted in the paper where we describe the formulation. However, we recently ran additional experiments in which we pre-trained with negatives sampled only from the scene that contains the positive. The results are very similar, at least on the S3DIS semantic segmentation task we tested.
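
For what it's worth, a minimal way to implement that same-scene variant is to mask out cross-scene keys in the logit matrix before the softmax. This is only a sketch, not the released code; scene_id_q and scene_id_k are hypothetical tensors holding the batch index of each sampled query/key point, and T stands in for the temperature:

    import torch
    import torch.nn.functional as F

    # Sketch only: keep same-scene keys, mask out cross-scene ones.
    # scene_id_q / scene_id_k: hypothetical [npos] long tensors with the
    # batch index of each sampled query / key point.
    logits = torch.mm(q, k.t()) / T                           # [npos, npos]
    same_scene = scene_id_q.unsqueeze(1) == scene_id_k.unsqueeze(0)
    logits = logits.masked_fill(~same_scene, float('-inf'))   # drop cross-scene negatives
    labels = torch.arange(q.shape[0], device=q.device)        # positives on the diagonal
    loss = F.cross_entropy(logits, labels)

The masked entries contribute exp(-inf) = 0 to the softmax denominator, so each query only competes against keys from its own scene.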

Let me know if you have further questions.

liu115 commented 3 years ago

Thank you for the clarification.